Commit c2e19c97 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 368935233
parent 127c9d80
......@@ -14,7 +14,10 @@
"""Contains definitions of dense prediction heads."""
from typing import List, Mapping, Optional, Tuple, Union
# Import libraries
import numpy as np
import tensorflow as tf
......@@ -25,22 +28,23 @@ from official.modeling import tf_utils
class RetinaNetHead(tf.keras.layers.Layer):
"""Creates a RetinaNet head."""
def __init__(self,
min_level,
max_level,
num_classes,
num_anchors_per_location,
num_convs=4,
num_filters=256,
attribute_heads=None,
use_separable_conv=False,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
def __init__(
self,
min_level: int,
max_level: int,
num_classes: int,
num_anchors_per_location: int,
num_convs: int = 4,
num_filters: int = 256,
attribute_heads: Mapping[str, Tuple[str, int]] = None,
use_separable_conv: bool = False,
activation: str = 'relu',
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
**kwargs):
"""Initializes a RetinaNet head.
Args:
......@@ -93,7 +97,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
self._bn_axis = 1
self._activation = tf_utils.get_activation(activation)
def build(self, input_shape):
def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
"""Creates the variables of the head."""
conv_op = (tf.keras.layers.SeparableConv2D
if self._config_dict['use_separable_conv']
......@@ -239,7 +243,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
super(RetinaNetHead, self).build(input_shape)
def call(self, features):
def call(self, features: Mapping[str, tf.Tensor]):
"""Forward pass of the RetinaNet head.
Args:
......@@ -325,20 +329,21 @@ class RetinaNetHead(tf.keras.layers.Layer):
class RPNHead(tf.keras.layers.Layer):
"""Creates a Region Proposal Network (RPN) head."""
def __init__(self,
min_level,
max_level,
num_anchors_per_location,
num_convs=1,
num_filters=256,
use_separable_conv=False,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
def __init__(
self,
min_level: int,
max_level: int,
num_anchors_per_location: int,
num_convs: int = 1,
num_filters: int = 256,
use_separable_conv: bool = False,
activation: str = 'relu',
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
**kwargs):
"""Initializes a Region Proposal Network head.
Args:
......@@ -457,7 +462,7 @@ class RPNHead(tf.keras.layers.Layer):
super(RPNHead, self).build(input_shape)
def call(self, features):
def call(self, features: Mapping[str, tf.Tensor]):
"""Forward pass of the RPN head.
Args:
......
......@@ -14,6 +14,7 @@
"""Contains definitions of instance prediction heads."""
from typing import List, Union, Optional
# Import libraries
import tensorflow as tf
......@@ -24,20 +25,21 @@ from official.modeling import tf_utils
class DetectionHead(tf.keras.layers.Layer):
"""Creates a detection head."""
def __init__(self,
num_classes,
num_convs=0,
num_filters=256,
use_separable_conv=False,
num_fcs=2,
fc_dims=1024,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
def __init__(
self,
num_classes: int,
num_convs: int = 0,
num_filters: int = 256,
use_separable_conv: bool = False,
num_fcs: int = 2,
fc_dims: int = 1024,
activation: str = 'relu',
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
**kwargs):
"""Initializes a detection head.
Args:
......@@ -85,7 +87,7 @@ class DetectionHead(tf.keras.layers.Layer):
self._bn_axis = 1
self._activation = tf_utils.get_activation(activation)
def build(self, input_shape):
def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
"""Creates the variables of the head."""
conv_op = (tf.keras.layers.SeparableConv2D
if self._config_dict['use_separable_conv']
......@@ -163,7 +165,7 @@ class DetectionHead(tf.keras.layers.Layer):
super(DetectionHead, self).build(input_shape)
def call(self, inputs, training=None):
def call(self, inputs: tf.Tensor, training: bool = None):
"""Forward pass of box and class branches for the Mask-RCNN model.
Args:
......@@ -211,20 +213,21 @@ class DetectionHead(tf.keras.layers.Layer):
class MaskHead(tf.keras.layers.Layer):
"""Creates a mask head."""
def __init__(self,
num_classes,
upsample_factor=2,
num_convs=4,
num_filters=256,
use_separable_conv=False,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_regularizer=None,
bias_regularizer=None,
class_agnostic=False,
**kwargs):
def __init__(
self,
num_classes: int,
upsample_factor: int = 2,
num_convs: int = 4,
num_filters: int = 256,
use_separable_conv: bool = False,
activation: str = 'relu',
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
class_agnostic: bool = False,
**kwargs):
"""Initializes a mask head.
Args:
......@@ -272,7 +275,7 @@ class MaskHead(tf.keras.layers.Layer):
self._bn_axis = 1
self._activation = tf_utils.get_activation(activation)
def build(self, input_shape):
def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
"""Creates the variables of the head."""
conv_op = (tf.keras.layers.SeparableConv2D
if self._config_dict['use_separable_conv']
......@@ -364,7 +367,7 @@ class MaskHead(tf.keras.layers.Layer):
super(MaskHead, self).build(input_shape)
def call(self, inputs, training=None):
def call(self, inputs: List[tf.Tensor], training: bool = None):
"""Forward pass of mask branch for the Mask-RCNN model.
Args:
......
......@@ -13,7 +13,7 @@
# limitations under the License.
"""Contains definitions of segmentation heads."""
from typing import List, Union, Optional, Mapping
import tensorflow as tf
from official.modeling import tf_utils
......@@ -25,23 +25,24 @@ from official.vision.beta.ops import spatial_transform_ops
class SegmentationHead(tf.keras.layers.Layer):
"""Creates a segmentation head."""
def __init__(self,
num_classes,
level,
num_convs=2,
num_filters=256,
prediction_kernel_size=1,
upsample_factor=1,
feature_fusion=None,
low_level=2,
low_level_num_filters=48,
activation='relu',
use_sync_bn=False,
norm_momentum=0.99,
norm_epsilon=0.001,
kernel_regularizer=None,
bias_regularizer=None,
**kwargs):
def __init__(
self,
num_classes: int,
level: Union[int, str],
num_convs: int = 2,
num_filters: int = 256,
prediction_kernel_size: int = 1,
upsample_factor: int = 1,
feature_fusion: Optional[str] = None,
low_level: int = 2,
low_level_num_filters: int = 48,
activation: str = 'relu',
use_sync_bn: bool = False,
norm_momentum: float = 0.99,
norm_epsilon: float = 0.001,
kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
**kwargs):
"""Initializes a segmentation head.
Args:
......@@ -101,7 +102,7 @@ class SegmentationHead(tf.keras.layers.Layer):
self._bn_axis = 1
self._activation = tf_utils.get_activation(activation)
def build(self, input_shape):
def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
"""Creates the variables of the segmentation head."""
conv_op = tf.keras.layers.Conv2D
conv_kwargs = {
......@@ -159,7 +160,8 @@ class SegmentationHead(tf.keras.layers.Layer):
super(SegmentationHead, self).build(input_shape)
def call(self, backbone_output, decoder_output):
def call(self, backbone_output: Mapping[str, tf.Tensor],
decoder_output: Mapping[str, tf.Tensor]):
"""Forward pass of the segmentation head.
Args:
......
......@@ -25,8 +25,8 @@ class BoxSampler(tf.keras.layers.Layer):
"""Creates a BoxSampler to sample positive and negative boxes."""
def __init__(self,
num_samples=512,
foreground_fraction=0.25,
num_samples: int = 512,
foreground_fraction: float = 0.25,
**kwargs):
"""Initializes a box sampler.
......@@ -42,7 +42,8 @@ class BoxSampler(tf.keras.layers.Layer):
}
super(BoxSampler, self).__init__(**kwargs)
def call(self, positive_matches, negative_matches, ignored_matches):
def call(self, positive_matches: tf.Tensor, negative_matches: tf.Tensor,
ignored_matches: tf.Tensor):
"""Samples and selects positive and negative instances.
Args:
......
......@@ -13,7 +13,7 @@
# limitations under the License.
"""Contains definitions of generators to generate the final detections."""
from typing import Optional, Mapping
# Import libraries
import tensorflow as tf
......@@ -21,13 +21,14 @@ from official.vision.beta.ops import box_ops
from official.vision.beta.ops import nms
def _generate_detections_v1(boxes,
scores,
attributes=None,
pre_nms_top_k=5000,
pre_nms_score_threshold=0.05,
nms_iou_threshold=0.5,
max_num_detections=100):
def _generate_detections_v1(boxes: tf.Tensor,
scores: tf.Tensor,
attributes: Optional[Mapping[str,
tf.Tensor]] = None,
pre_nms_top_k: int = 5000,
pre_nms_score_threshold: float = 0.05,
nms_iou_threshold: float = 0.5,
max_num_detections: int = 100):
"""Generates the final detections given the model outputs.
The implementation unrolls the batch dimension and processes images one by one.
......@@ -117,13 +118,14 @@ def _generate_detections_v1(boxes,
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, nmsed_attributes
def _generate_detections_per_image(boxes,
scores,
attributes=None,
pre_nms_top_k=5000,
pre_nms_score_threshold=0.05,
nms_iou_threshold=0.5,
max_num_detections=100):
def _generate_detections_per_image(
boxes: tf.Tensor,
scores: tf.Tensor,
attributes: Optional[Mapping[str, tf.Tensor]] = None,
pre_nms_top_k: int = 5000,
pre_nms_score_threshold: float = 0.05,
nms_iou_threshold: float = 0.5,
max_num_detections: int = 100):
"""Generates the final detections per image given the model outputs.
Args:
......@@ -225,7 +227,7 @@ def _generate_detections_per_image(boxes,
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, nmsed_attributes
def _select_top_k_scores(scores_in, pre_nms_num_detections):
def _select_top_k_scores(scores_in: tf.Tensor, pre_nms_num_detections: int):
"""Selects top_k scores and indices for each class.
Args:
......@@ -255,12 +257,12 @@ def _select_top_k_scores(scores_in, pre_nms_num_detections):
[0, 2, 1]), tf.transpose(top_k_indices, [0, 2, 1])
def _generate_detections_v2(boxes,
scores,
pre_nms_top_k=5000,
pre_nms_score_threshold=0.05,
nms_iou_threshold=0.5,
max_num_detections=100):
def _generate_detections_v2(boxes: tf.Tensor,
scores: tf.Tensor,
pre_nms_top_k: int = 5000,
pre_nms_score_threshold: float = 0.05,
nms_iou_threshold: float = 0.5,
max_num_detections: int = 100):
"""Generates the final detections given the model outputs.
This implementation unrolls classes dimension while using the tf.while_loop
......@@ -337,11 +339,10 @@ def _generate_detections_v2(boxes,
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
def _generate_detections_batched(boxes,
scores,
pre_nms_score_threshold,
nms_iou_threshold,
max_num_detections):
def _generate_detections_batched(boxes: tf.Tensor, scores: tf.Tensor,
pre_nms_score_threshold: float,
nms_iou_threshold: float,
max_num_detections: int):
"""Generates detected boxes with scores and classes for one-stage detector.
The function takes output of multi-level ConvNets and anchor boxes and
......@@ -393,12 +394,12 @@ class DetectionGenerator(tf.keras.layers.Layer):
"""Generates the final detected boxes with scores and classes."""
def __init__(self,
apply_nms=True,
pre_nms_top_k=5000,
pre_nms_score_threshold=0.05,
nms_iou_threshold=0.5,
max_num_detections=100,
use_batched_nms=False,
apply_nms: bool = True,
pre_nms_top_k: int = 5000,
pre_nms_score_threshold: float = 0.05,
nms_iou_threshold: float = 0.5,
max_num_detections: int = 100,
use_batched_nms: bool = False,
**kwargs):
"""Initializes a detection generator.
......@@ -427,11 +428,8 @@ class DetectionGenerator(tf.keras.layers.Layer):
}
super(DetectionGenerator, self).__init__(**kwargs)
def __call__(self,
raw_boxes,
raw_scores,
anchor_boxes,
image_shape):
def __call__(self, raw_boxes: tf.Tensor, raw_scores: tf.Tensor,
anchor_boxes: tf.Tensor, image_shape: tf.Tensor):
"""Generates final detections.
Args:
......@@ -546,12 +544,12 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
"""Generates detected boxes with scores and classes for one-stage detector."""
def __init__(self,
apply_nms=True,
pre_nms_top_k=5000,
pre_nms_score_threshold=0.05,
nms_iou_threshold=0.5,
max_num_detections=100,
use_batched_nms=False,
apply_nms: bool = True,
pre_nms_top_k: int = 5000,
pre_nms_score_threshold: float = 0.05,
nms_iou_threshold: float = 0.5,
max_num_detections: int = 100,
use_batched_nms: bool = False,
**kwargs):
"""Initializes a multi-level detection generator.
......@@ -581,11 +579,11 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
super(MultilevelDetectionGenerator, self).__init__(**kwargs)
def __call__(self,
raw_boxes,
raw_scores,
anchor_boxes,
image_shape,
raw_attributes=None):
raw_boxes: Mapping[str, tf.Tensor],
raw_scores: Mapping[str, tf.Tensor],
anchor_boxes: tf.Tensor,
image_shape: tf.Tensor,
raw_attributes: Mapping[str, tf.Tensor] = None):
"""Generates final detections.
Args:
......@@ -600,11 +598,10 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
image_shape: A `tf.Tensor` of shape of [batch_size, 2] storing the image
height and width w.r.t. the scaled image, i.e. the same image space as
`box_outputs` and `anchor_boxes`.
raw_attributes: If not None, a `dict` of
(attribute_name, attribute_prediction) pairs. `attribute_prediction`
is a dict that contains keys representing FPN levels and values
representing tenors of shape `[batch, feature_h, feature_w,
num_anchors * attribute_size]`.
raw_attributes: If not None, a `dict` of (attribute_name,
attribute_prediction) pairs. `attribute_prediction` is a dict that
contains keys representing FPN levels and values representing tensors of
shape `[batch, feature_h, feature_w, num_anchors * attribute_size]`.
Returns:
If `apply_nms` = True, the return is a dictionary with keys:
......
......@@ -20,13 +20,13 @@ import tensorflow as tf
from official.vision.beta.ops import spatial_transform_ops
def _sample_and_crop_foreground_masks(candidate_rois,
candidate_gt_boxes,
candidate_gt_classes,
candidate_gt_indices,
gt_masks,
num_sampled_masks=128,
mask_target_size=28):
def _sample_and_crop_foreground_masks(candidate_rois: tf.Tensor,
candidate_gt_boxes: tf.Tensor,
candidate_gt_classes: tf.Tensor,
candidate_gt_indices: tf.Tensor,
gt_masks: tf.Tensor,
num_sampled_masks: int = 128,
mask_target_size: int = 28):
"""Samples and creates cropped foreground masks for training.
Args:
......@@ -104,22 +104,16 @@ def _sample_and_crop_foreground_masks(candidate_rois,
class MaskSampler(tf.keras.layers.Layer):
"""Samples and creates mask training targets."""
def __init__(self,
mask_target_size,
num_sampled_masks,
**kwargs):
def __init__(self, mask_target_size: int, num_sampled_masks: int, **kwargs):
self._config_dict = {
'mask_target_size': mask_target_size,
'num_sampled_masks': num_sampled_masks,
}
super(MaskSampler, self).__init__(**kwargs)
def call(self,
candidate_rois,
candidate_gt_boxes,
candidate_gt_classes,
candidate_gt_indices,
gt_masks):
def call(self, candidate_rois: tf.Tensor, candidate_gt_boxes: tf.Tensor,
candidate_gt_classes: tf.Tensor, candidate_gt_indices: tf.Tensor,
gt_masks: tf.Tensor):
"""Samples and creates mask targets for training.
Args:
......
......@@ -14,6 +14,7 @@
"""Contains definitions of ROI aligner."""
from typing import Mapping
import tensorflow as tf
from official.vision.beta.ops import spatial_transform_ops
......@@ -23,10 +24,7 @@ from official.vision.beta.ops import spatial_transform_ops
class MultilevelROIAligner(tf.keras.layers.Layer):
"""Performs ROIAlign for the second stage processing."""
def __init__(self,
crop_size=7,
sample_offset=0.5,
**kwargs):
def __init__(self, crop_size: int = 7, sample_offset: float = 0.5, **kwargs):
"""Initializes a ROI aligner.
Args:
......@@ -40,7 +38,10 @@ class MultilevelROIAligner(tf.keras.layers.Layer):
}
super(MultilevelROIAligner, self).__init__(**kwargs)
def call(self, features, boxes, training=None):
def call(self,
features: Mapping[str, tf.Tensor],
boxes: tf.Tensor,
training: bool = None):
"""Generates ROIs.
Args:
......
......@@ -13,7 +13,7 @@
# limitations under the License.
"""Contains definitions of ROI generator."""
from typing import Optional, Mapping
# Import libraries
import tensorflow as tf
......@@ -21,19 +21,19 @@ from official.vision.beta.ops import box_ops
from official.vision.beta.ops import nms
def _multilevel_propose_rois(raw_boxes,
raw_scores,
anchor_boxes,
image_shape,
pre_nms_top_k=2000,
pre_nms_score_threshold=0.0,
pre_nms_min_size_threshold=0.0,
nms_iou_threshold=0.7,
num_proposals=1000,
use_batched_nms=False,
decode_boxes=True,
clip_boxes=True,
apply_sigmoid_to_score=True):
def _multilevel_propose_rois(raw_boxes: Mapping[str, tf.Tensor],
raw_scores: Mapping[str, tf.Tensor],
anchor_boxes: Mapping[str, tf.Tensor],
image_shape: tf.Tensor,
pre_nms_top_k: int = 2000,
pre_nms_score_threshold: float = 0.0,
pre_nms_min_size_threshold: float = 0.0,
nms_iou_threshold: float = 0.7,
num_proposals: int = 1000,
use_batched_nms: bool = False,
decode_boxes: bool = True,
clip_boxes: bool = True,
apply_sigmoid_to_score: bool = True):
"""Proposes RoIs given a group of candidates from different FPN levels.
The following describes the steps:
......@@ -181,17 +181,17 @@ class MultilevelROIGenerator(tf.keras.layers.Layer):
"""Proposes RoIs for the second stage processing."""
def __init__(self,
pre_nms_top_k=2000,
pre_nms_score_threshold=0.0,
pre_nms_min_size_threshold=0.0,
nms_iou_threshold=0.7,
num_proposals=1000,
test_pre_nms_top_k=1000,
test_pre_nms_score_threshold=0.0,
test_pre_nms_min_size_threshold=0.0,
test_nms_iou_threshold=0.7,
test_num_proposals=1000,
use_batched_nms=False,
pre_nms_top_k: int = 2000,
pre_nms_score_threshold: float = 0.0,
pre_nms_min_size_threshold: float = 0.0,
nms_iou_threshold: float = 0.7,
num_proposals: int = 1000,
test_pre_nms_top_k: int = 1000,
test_pre_nms_score_threshold: float = 0.0,
test_pre_nms_min_size_threshold: float = 0.0,
test_nms_iou_threshold: float = 0.7,
test_num_proposals: int = 1000,
use_batched_nms: bool = False,
**kwargs):
"""Initializes a ROI generator.
......@@ -240,11 +240,11 @@ class MultilevelROIGenerator(tf.keras.layers.Layer):
super(MultilevelROIGenerator, self).__init__(**kwargs)
def call(self,
raw_boxes,
raw_scores,
anchor_boxes,
image_shape,
training=None):
raw_boxes: Mapping[str, tf.Tensor],
raw_scores: Mapping[str, tf.Tensor],
anchor_boxes: Mapping[str, tf.Tensor],
image_shape: tf.Tensor,
training: Optional[bool] = None):
"""Proposes RoIs given a group of candidates from different FPN levels.
The following describes the steps:
......
......@@ -13,7 +13,6 @@
# limitations under the License.
"""Contains definitions of ROI sampler."""
# Import libraries
import tensorflow as tf
......@@ -26,12 +25,12 @@ class ROISampler(tf.keras.layers.Layer):
"""Samples ROIs and assigns targets to the sampled ROIs."""
def __init__(self,
mix_gt_boxes=True,
num_sampled_rois=512,
foreground_fraction=0.25,
foreground_iou_threshold=0.5,
background_iou_high_threshold=0.5,
background_iou_low_threshold=0,
mix_gt_boxes: bool = True,
num_sampled_rois: int = 512,
foreground_fraction: float = 0.25,
foreground_iou_threshold: float = 0.5,
background_iou_high_threshold: float = 0.5,
background_iou_low_threshold: float = 0,
**kwargs):
"""Initializes a ROI sampler.
......@@ -73,7 +72,7 @@ class ROISampler(tf.keras.layers.Layer):
num_sampled_rois, foreground_fraction)
super(ROISampler, self).__init__(**kwargs)
def call(self, boxes, gt_boxes, gt_classes):
def call(self, boxes: tf.Tensor, gt_boxes: tf.Tensor, gt_classes: tf.Tensor):
"""Assigns the proposals with groundtruth classes and performs subsmpling.
Given `proposed_boxes`, `gt_boxes`, and `gt_classes`, the function uses the
......
......@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Image classification task definition."""
from typing import Any, Optional, List, Tuple
from absl import logging
import tensorflow as tf
......@@ -51,7 +51,7 @@ class ImageClassificationTask(base_task.Task):
return model
def initialize(self, model: tf.keras.Model):
"""Loading pretrained checkpoint."""
"""Loads pretrained checkpoint."""
if not self.task_config.init_checkpoint:
return
......@@ -75,7 +75,9 @@ class ImageClassificationTask(base_task.Task):
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_inputs(self, params, input_context=None):
def build_inputs(self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Builds classification input."""
num_classes = self.task_config.model.num_classes
......@@ -112,13 +114,16 @@ class ImageClassificationTask(base_task.Task):
return dataset
def build_losses(self, labels, model_outputs, aux_losses=None):
"""Sparse categorical cross entropy loss.
def build_losses(self,
labels: tf.Tensor,
model_outputs: tf.Tensor,
aux_losses: Optional[Any] = None):
"""Builds sparse categorical cross entropy loss.
Args:
labels: labels.
labels: Input groundtruth labels.
model_outputs: Output logits of the classifier.
aux_losses: auxiliarly loss tensors, i.e. `losses` in keras.Model.
aux_losses: The auxiliary loss tensors, i.e. `losses` in tf.keras.Model.
Returns:
The total loss tensor.
......@@ -140,7 +145,7 @@ class ImageClassificationTask(base_task.Task):
return total_loss
def build_metrics(self, training=True):
def build_metrics(self, training: bool = True):
"""Gets streaming metrics for training/validation."""
k = self.task_config.evaluation.top_k
if self.task_config.losses.one_hot:
......@@ -155,14 +160,18 @@ class ImageClassificationTask(base_task.Task):
k=k, name='top_{}_accuracy'.format(k))]
return metrics
def train_step(self, inputs, model, optimizer, metrics=None):
def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
inputs: A tuple of input tensors of (features, labels).
model: A tf.keras.Model instance.
optimizer: The optimizer for this training step.
metrics: A nested structure of metrics objects.
Returns:
A dictionary of logs.
......@@ -209,13 +218,16 @@ class ImageClassificationTask(base_task.Task):
logs.update({m.name: m.result() for m in model.metrics})
return logs
def validation_step(self, inputs, model, metrics=None):
"""Validatation step.
def validation_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None):
"""Runs validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
inputs: A tuple of input tensors of (features, labels).
model: A tf.keras.Model instance.
metrics: A nested structure of metrics objects.
Returns:
A dictionary of logs.
......@@ -237,6 +249,6 @@ class ImageClassificationTask(base_task.Task):
logs.update({m.name: m.result() for m in model.metrics})
return logs
def inference_step(self, inputs, model):
def inference_step(self, inputs: tf.Tensor, model: tf.keras.Model):
"""Performs the forward step."""
return model(inputs, training=False)
......@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""RetinaNet task definition."""
from typing import Any, Optional, List, Tuple, Mapping
from absl import logging
import tensorflow as tf
......@@ -30,7 +30,8 @@ from official.vision.beta.losses import maskrcnn_losses
from official.vision.beta.modeling import factory
def zero_out_disallowed_class_ids(batch_class_ids, allowed_class_ids):
def zero_out_disallowed_class_ids(batch_class_ids: tf.Tensor,
allowed_class_ids: List[int]):
"""Zero out IDs of classes not in allowed_class_ids.
Args:
......@@ -106,7 +107,9 @@ class MaskRCNNTask(base_task.Task):
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_inputs(self, params, input_context=None):
def build_inputs(self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Build input dataset."""
decoder_cfg = params.decoder.get()
if params.decoder.type == 'simple_decoder':
......@@ -152,7 +155,10 @@ class MaskRCNNTask(base_task.Task):
return dataset
def build_losses(self, outputs, labels, aux_losses=None):
def build_losses(self,
outputs: Mapping[str, Any],
labels: Mapping[str, Any],
aux_losses: Optional[Any] = None):
"""Build Mask R-CNN losses."""
params = self.task_config
......@@ -218,7 +224,7 @@ class MaskRCNNTask(base_task.Task):
}
return losses
def build_metrics(self, training=True):
def build_metrics(self, training: bool = True):
"""Build detection metrics."""
metrics = []
if training:
......@@ -242,7 +248,11 @@ class MaskRCNNTask(base_task.Task):
return metrics
def train_step(self, inputs, model, optimizer, metrics=None):
def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward.
Args:
......@@ -294,7 +304,10 @@ class MaskRCNNTask(base_task.Task):
return logs
def validation_step(self, inputs, model, metrics=None):
def validation_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None):
"""Validatation step.
Args:
......
......@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""RetinaNet task definition."""
from typing import Any, Optional, List, Tuple, Mapping
from absl import logging
import tensorflow as tf
......@@ -84,7 +84,9 @@ class RetinaNetTask(base_task.Task):
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_inputs(self, params, input_context=None):
def build_inputs(self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Build input dataset."""
if params.tfds_name:
......@@ -131,7 +133,10 @@ class RetinaNetTask(base_task.Task):
return dataset
def build_losses(self, outputs, labels, aux_losses=None):
def build_losses(self,
outputs: Mapping[str, Any],
labels: Mapping[str, Any],
aux_losses: Optional[Any] = None):
"""Build RetinaNet losses."""
params = self.task_config
cls_loss_fn = keras_cv.losses.FocalLoss(
......@@ -172,7 +177,7 @@ class RetinaNetTask(base_task.Task):
return total_loss, cls_loss, box_loss, model_loss
def build_metrics(self, training=True):
def build_metrics(self, training: bool = True):
"""Build detection metrics."""
metrics = []
metric_names = ['total_loss', 'cls_loss', 'box_loss', 'model_loss']
......@@ -190,7 +195,11 @@ class RetinaNetTask(base_task.Task):
return metrics
def train_step(self, inputs, model, optimizer, metrics=None):
def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward.
Args:
......@@ -241,7 +250,10 @@ class RetinaNetTask(base_task.Task):
return logs
def validation_step(self, inputs, model, metrics=None):
def validation_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None):
"""Validatation step.
Args:
......
......@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Image segmentation task definition."""
from typing import Any, Optional, List, Tuple, Mapping, Union
from absl import logging
import tensorflow as tf
......@@ -79,7 +79,9 @@ class SemanticSegmentationTask(base_task.Task):
logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file)
def build_inputs(self, params, input_context=None):
def build_inputs(self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Builds classification input."""
ignore_label = self.task_config.losses.ignore_label
......@@ -114,7 +116,10 @@ class SemanticSegmentationTask(base_task.Task):
return dataset
def build_losses(self, labels, model_outputs, aux_losses=None):
def build_losses(self,
labels: Mapping[str, tf.Tensor],
model_outputs: Union[Mapping[str, tf.Tensor], tf.Tensor],
aux_losses: Optional[Any] = None):
"""Segmentation loss.
Args:
......@@ -140,7 +145,7 @@ class SemanticSegmentationTask(base_task.Task):
return total_loss
def build_metrics(self, training=True):
def build_metrics(self, training: bool = True):
"""Gets streaming metrics for training/validation."""
metrics = []
if training and self.task_config.evaluation.report_train_mean_iou:
......@@ -159,7 +164,11 @@ class SemanticSegmentationTask(base_task.Task):
return metrics
def train_step(self, inputs, model, optimizer, metrics=None):
def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward.
Args:
......@@ -214,7 +223,10 @@ class SemanticSegmentationTask(base_task.Task):
return logs
def validation_step(self, inputs, model, metrics=None):
def validation_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None):
"""Validatation step.
Args:
......@@ -251,7 +263,7 @@ class SemanticSegmentationTask(base_task.Task):
return logs
def inference_step(self, inputs, model):
def inference_step(self, inputs: tf.Tensor, model: tf.keras.Model):
"""Performs the forward step."""
return model(inputs, training=False)
......
......@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Video classification task definition."""
from typing import Any, Optional, List, Tuple
from absl import logging
import tensorflow as tf
from official.core import base_task
......@@ -68,7 +69,9 @@ class VideoClassificationTask(base_task.Task):
tf.io.VarLenFeature(dtype=tf.float32))
return decoder.decode
def build_inputs(self, params: exp_cfg.DataConfig, input_context=None):
def build_inputs(self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Builds classification input."""
parser = video_input.Parser(input_params=params)
......@@ -85,7 +88,10 @@ class VideoClassificationTask(base_task.Task):
return dataset
def build_losses(self, labels, model_outputs, aux_losses=None):
def build_losses(self,
labels: Any,
model_outputs: Any,
aux_losses: Optional[Any] = None):
"""Sparse categorical cross entropy loss.
Args:
......@@ -132,7 +138,7 @@ class VideoClassificationTask(base_task.Task):
return all_losses
def build_metrics(self, training=True):
def build_metrics(self, training: bool = True):
"""Gets streaming metrics for training/validation."""
if self.task_config.losses.one_hot:
metrics = [
......@@ -168,7 +174,8 @@ class VideoClassificationTask(base_task.Task):
]
return metrics
def process_metrics(self, metrics, labels, model_outputs):
def process_metrics(self, metrics: List[Any], labels: Any,
model_outputs: Any):
"""Process and update metrics.
Called when using custom training loop API.
......@@ -183,7 +190,11 @@ class VideoClassificationTask(base_task.Task):
for metric in metrics:
metric.update_state(labels, model_outputs)
def train_step(self, inputs, model, optimizer, metrics=None):
def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward.
Args:
......@@ -240,7 +251,10 @@ class VideoClassificationTask(base_task.Task):
logs.update({m.name: m.result() for m in model.metrics})
return logs
def validation_step(self, inputs, model, metrics=None):
def validation_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None):
"""Validatation step.
Args:
......@@ -266,7 +280,7 @@ class VideoClassificationTask(base_task.Task):
logs.update({m.name: m.result() for m in model.metrics})
return logs
def inference_step(self, features, model):
def inference_step(self, features: tf.Tensor, model: tf.keras.Model):
"""Performs the forward step."""
outputs = model(features, training=False)
if self.task_config.train_data.is_multilabel:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment