Commit c2e19c97 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change.

PiperOrigin-RevId: 368935233
parent 127c9d80
...@@ -14,7 +14,10 @@ ...@@ -14,7 +14,10 @@
"""Contains definitions of dense prediction heads.""" """Contains definitions of dense prediction heads."""
from typing import List, Mapping, Optional, Tuple, Union
# Import libraries # Import libraries
import numpy as np import numpy as np
import tensorflow as tf import tensorflow as tf
...@@ -25,21 +28,22 @@ from official.modeling import tf_utils ...@@ -25,21 +28,22 @@ from official.modeling import tf_utils
class RetinaNetHead(tf.keras.layers.Layer): class RetinaNetHead(tf.keras.layers.Layer):
"""Creates a RetinaNet head.""" """Creates a RetinaNet head."""
def __init__(self, def __init__(
min_level, self,
max_level, min_level: int,
num_classes, max_level: int,
num_anchors_per_location, num_classes: int,
num_convs=4, num_anchors_per_location: int,
num_filters=256, num_convs: int = 4,
attribute_heads=None, num_filters: int = 256,
use_separable_conv=False, attribute_heads: Mapping[str, Tuple[str, int]] = None,
activation='relu', use_separable_conv: bool = False,
use_sync_bn=False, activation: str = 'relu',
norm_momentum=0.99, use_sync_bn: bool = False,
norm_epsilon=0.001, norm_momentum: float = 0.99,
kernel_regularizer=None, norm_epsilon: float = 0.001,
bias_regularizer=None, kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
**kwargs): **kwargs):
"""Initializes a RetinaNet head. """Initializes a RetinaNet head.
...@@ -93,7 +97,7 @@ class RetinaNetHead(tf.keras.layers.Layer): ...@@ -93,7 +97,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
self._bn_axis = 1 self._bn_axis = 1
self._activation = tf_utils.get_activation(activation) self._activation = tf_utils.get_activation(activation)
def build(self, input_shape): def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
"""Creates the variables of the head.""" """Creates the variables of the head."""
conv_op = (tf.keras.layers.SeparableConv2D conv_op = (tf.keras.layers.SeparableConv2D
if self._config_dict['use_separable_conv'] if self._config_dict['use_separable_conv']
...@@ -239,7 +243,7 @@ class RetinaNetHead(tf.keras.layers.Layer): ...@@ -239,7 +243,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
super(RetinaNetHead, self).build(input_shape) super(RetinaNetHead, self).build(input_shape)
def call(self, features): def call(self, features: Mapping[str, tf.Tensor]):
"""Forward pass of the RetinaNet head. """Forward pass of the RetinaNet head.
Args: Args:
...@@ -325,19 +329,20 @@ class RetinaNetHead(tf.keras.layers.Layer): ...@@ -325,19 +329,20 @@ class RetinaNetHead(tf.keras.layers.Layer):
class RPNHead(tf.keras.layers.Layer): class RPNHead(tf.keras.layers.Layer):
"""Creates a Region Proposal Network (RPN) head.""" """Creates a Region Proposal Network (RPN) head."""
def __init__(self, def __init__(
min_level, self,
max_level, min_level: int,
num_anchors_per_location, max_level: int,
num_convs=1, num_anchors_per_location: int,
num_filters=256, num_convs: int = 1,
use_separable_conv=False, num_filters: int = 256,
activation='relu', use_separable_conv: bool = False,
use_sync_bn=False, activation: str = 'relu',
norm_momentum=0.99, use_sync_bn: bool = False,
norm_epsilon=0.001, norm_momentum: float = 0.99,
kernel_regularizer=None, norm_epsilon: float = 0.001,
bias_regularizer=None, kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
**kwargs): **kwargs):
"""Initializes a Region Proposal Network head. """Initializes a Region Proposal Network head.
...@@ -457,7 +462,7 @@ class RPNHead(tf.keras.layers.Layer): ...@@ -457,7 +462,7 @@ class RPNHead(tf.keras.layers.Layer):
super(RPNHead, self).build(input_shape) super(RPNHead, self).build(input_shape)
def call(self, features): def call(self, features: Mapping[str, tf.Tensor]):
"""Forward pass of the RPN head. """Forward pass of the RPN head.
Args: Args:
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
"""Contains definitions of instance prediction heads.""" """Contains definitions of instance prediction heads."""
from typing import List, Union, Optional
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
...@@ -24,19 +25,20 @@ from official.modeling import tf_utils ...@@ -24,19 +25,20 @@ from official.modeling import tf_utils
class DetectionHead(tf.keras.layers.Layer): class DetectionHead(tf.keras.layers.Layer):
"""Creates a detection head.""" """Creates a detection head."""
def __init__(self, def __init__(
num_classes, self,
num_convs=0, num_classes: int,
num_filters=256, num_convs: int = 0,
use_separable_conv=False, num_filters: int = 256,
num_fcs=2, use_separable_conv: bool = False,
fc_dims=1024, num_fcs: int = 2,
activation='relu', fc_dims: int = 1024,
use_sync_bn=False, activation: str = 'relu',
norm_momentum=0.99, use_sync_bn: bool = False,
norm_epsilon=0.001, norm_momentum: float = 0.99,
kernel_regularizer=None, norm_epsilon: float = 0.001,
bias_regularizer=None, kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
**kwargs): **kwargs):
"""Initializes a detection head. """Initializes a detection head.
...@@ -85,7 +87,7 @@ class DetectionHead(tf.keras.layers.Layer): ...@@ -85,7 +87,7 @@ class DetectionHead(tf.keras.layers.Layer):
self._bn_axis = 1 self._bn_axis = 1
self._activation = tf_utils.get_activation(activation) self._activation = tf_utils.get_activation(activation)
def build(self, input_shape): def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
"""Creates the variables of the head.""" """Creates the variables of the head."""
conv_op = (tf.keras.layers.SeparableConv2D conv_op = (tf.keras.layers.SeparableConv2D
if self._config_dict['use_separable_conv'] if self._config_dict['use_separable_conv']
...@@ -163,7 +165,7 @@ class DetectionHead(tf.keras.layers.Layer): ...@@ -163,7 +165,7 @@ class DetectionHead(tf.keras.layers.Layer):
super(DetectionHead, self).build(input_shape) super(DetectionHead, self).build(input_shape)
def call(self, inputs, training=None): def call(self, inputs: tf.Tensor, training: bool = None):
"""Forward pass of box and class branches for the Mask-RCNN model. """Forward pass of box and class branches for the Mask-RCNN model.
Args: Args:
...@@ -211,19 +213,20 @@ class DetectionHead(tf.keras.layers.Layer): ...@@ -211,19 +213,20 @@ class DetectionHead(tf.keras.layers.Layer):
class MaskHead(tf.keras.layers.Layer): class MaskHead(tf.keras.layers.Layer):
"""Creates a mask head.""" """Creates a mask head."""
def __init__(self, def __init__(
num_classes, self,
upsample_factor=2, num_classes: int,
num_convs=4, upsample_factor: int = 2,
num_filters=256, num_convs: int = 4,
use_separable_conv=False, num_filters: int = 256,
activation='relu', use_separable_conv: bool = False,
use_sync_bn=False, activation: str = 'relu',
norm_momentum=0.99, use_sync_bn: bool = False,
norm_epsilon=0.001, norm_momentum: float = 0.99,
kernel_regularizer=None, norm_epsilon: float = 0.001,
bias_regularizer=None, kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
class_agnostic=False, bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
class_agnostic: bool = False,
**kwargs): **kwargs):
"""Initializes a mask head. """Initializes a mask head.
...@@ -272,7 +275,7 @@ class MaskHead(tf.keras.layers.Layer): ...@@ -272,7 +275,7 @@ class MaskHead(tf.keras.layers.Layer):
self._bn_axis = 1 self._bn_axis = 1
self._activation = tf_utils.get_activation(activation) self._activation = tf_utils.get_activation(activation)
def build(self, input_shape): def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
"""Creates the variables of the head.""" """Creates the variables of the head."""
conv_op = (tf.keras.layers.SeparableConv2D conv_op = (tf.keras.layers.SeparableConv2D
if self._config_dict['use_separable_conv'] if self._config_dict['use_separable_conv']
...@@ -364,7 +367,7 @@ class MaskHead(tf.keras.layers.Layer): ...@@ -364,7 +367,7 @@ class MaskHead(tf.keras.layers.Layer):
super(MaskHead, self).build(input_shape) super(MaskHead, self).build(input_shape)
def call(self, inputs, training=None): def call(self, inputs: List[tf.Tensor], training: bool = None):
"""Forward pass of mask branch for the Mask-RCNN model. """Forward pass of mask branch for the Mask-RCNN model.
Args: Args:
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains definitions of segmentation heads.""" """Contains definitions of segmentation heads."""
from typing import List, Union, Optional, Mapping
import tensorflow as tf import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
...@@ -25,22 +25,23 @@ from official.vision.beta.ops import spatial_transform_ops ...@@ -25,22 +25,23 @@ from official.vision.beta.ops import spatial_transform_ops
class SegmentationHead(tf.keras.layers.Layer): class SegmentationHead(tf.keras.layers.Layer):
"""Creates a segmentation head.""" """Creates a segmentation head."""
def __init__(self, def __init__(
num_classes, self,
level, num_classes: int,
num_convs=2, level: Union[int, str],
num_filters=256, num_convs: int = 2,
prediction_kernel_size=1, num_filters: int = 256,
upsample_factor=1, prediction_kernel_size: int = 1,
feature_fusion=None, upsample_factor: int = 1,
low_level=2, feature_fusion: Optional[str] = None,
low_level_num_filters=48, low_level: int = 2,
activation='relu', low_level_num_filters: int = 48,
use_sync_bn=False, activation: str = 'relu',
norm_momentum=0.99, use_sync_bn: bool = False,
norm_epsilon=0.001, norm_momentum: float = 0.99,
kernel_regularizer=None, norm_epsilon: float = 0.001,
bias_regularizer=None, kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
**kwargs): **kwargs):
"""Initializes a segmentation head. """Initializes a segmentation head.
...@@ -101,7 +102,7 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -101,7 +102,7 @@ class SegmentationHead(tf.keras.layers.Layer):
self._bn_axis = 1 self._bn_axis = 1
self._activation = tf_utils.get_activation(activation) self._activation = tf_utils.get_activation(activation)
def build(self, input_shape): def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
"""Creates the variables of the segmentation head.""" """Creates the variables of the segmentation head."""
conv_op = tf.keras.layers.Conv2D conv_op = tf.keras.layers.Conv2D
conv_kwargs = { conv_kwargs = {
...@@ -159,7 +160,8 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -159,7 +160,8 @@ class SegmentationHead(tf.keras.layers.Layer):
super(SegmentationHead, self).build(input_shape) super(SegmentationHead, self).build(input_shape)
def call(self, backbone_output, decoder_output): def call(self, backbone_output: Mapping[str, tf.Tensor],
decoder_output: Mapping[str, tf.Tensor]):
"""Forward pass of the segmentation head. """Forward pass of the segmentation head.
Args: Args:
......
...@@ -25,8 +25,8 @@ class BoxSampler(tf.keras.layers.Layer): ...@@ -25,8 +25,8 @@ class BoxSampler(tf.keras.layers.Layer):
"""Creates a BoxSampler to sample positive and negative boxes.""" """Creates a BoxSampler to sample positive and negative boxes."""
def __init__(self, def __init__(self,
num_samples=512, num_samples: int = 512,
foreground_fraction=0.25, foreground_fraction: float = 0.25,
**kwargs): **kwargs):
"""Initializes a box sampler. """Initializes a box sampler.
...@@ -42,7 +42,8 @@ class BoxSampler(tf.keras.layers.Layer): ...@@ -42,7 +42,8 @@ class BoxSampler(tf.keras.layers.Layer):
} }
super(BoxSampler, self).__init__(**kwargs) super(BoxSampler, self).__init__(**kwargs)
def call(self, positive_matches, negative_matches, ignored_matches): def call(self, positive_matches: tf.Tensor, negative_matches: tf.Tensor,
ignored_matches: tf.Tensor):
"""Samples and selects positive and negative instances. """Samples and selects positive and negative instances.
Args: Args:
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains definitions of generators to generate the final detections.""" """Contains definitions of generators to generate the final detections."""
from typing import Optional, Mapping
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
...@@ -21,13 +21,14 @@ from official.vision.beta.ops import box_ops ...@@ -21,13 +21,14 @@ from official.vision.beta.ops import box_ops
from official.vision.beta.ops import nms from official.vision.beta.ops import nms
def _generate_detections_v1(boxes, def _generate_detections_v1(boxes: tf.Tensor,
scores, scores: tf.Tensor,
attributes=None, attributes: Optional[Mapping[str,
pre_nms_top_k=5000, tf.Tensor]] = None,
pre_nms_score_threshold=0.05, pre_nms_top_k: int = 5000,
nms_iou_threshold=0.5, pre_nms_score_threshold: float = 0.05,
max_num_detections=100): nms_iou_threshold: float = 0.5,
max_num_detections: int = 100):
"""Generates the final detections given the model outputs. """Generates the final detections given the model outputs.
The implementation unrolls the batch dimension and process images one by one. The implementation unrolls the batch dimension and process images one by one.
...@@ -117,13 +118,14 @@ def _generate_detections_v1(boxes, ...@@ -117,13 +118,14 @@ def _generate_detections_v1(boxes,
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, nmsed_attributes return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, nmsed_attributes
def _generate_detections_per_image(boxes, def _generate_detections_per_image(
scores, boxes: tf.Tensor,
attributes=None, scores: tf.Tensor,
pre_nms_top_k=5000, attributes: Optional[Mapping[str, tf.Tensor]] = None,
pre_nms_score_threshold=0.05, pre_nms_top_k: int = 5000,
nms_iou_threshold=0.5, pre_nms_score_threshold: float = 0.05,
max_num_detections=100): nms_iou_threshold: float = 0.5,
max_num_detections: int = 100):
"""Generates the final detections per image given the model outputs. """Generates the final detections per image given the model outputs.
Args: Args:
...@@ -225,7 +227,7 @@ def _generate_detections_per_image(boxes, ...@@ -225,7 +227,7 @@ def _generate_detections_per_image(boxes,
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, nmsed_attributes return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, nmsed_attributes
def _select_top_k_scores(scores_in, pre_nms_num_detections): def _select_top_k_scores(scores_in: tf.Tensor, pre_nms_num_detections: int):
"""Selects top_k scores and indices for each class. """Selects top_k scores and indices for each class.
Args: Args:
...@@ -255,12 +257,12 @@ def _select_top_k_scores(scores_in, pre_nms_num_detections): ...@@ -255,12 +257,12 @@ def _select_top_k_scores(scores_in, pre_nms_num_detections):
[0, 2, 1]), tf.transpose(top_k_indices, [0, 2, 1]) [0, 2, 1]), tf.transpose(top_k_indices, [0, 2, 1])
def _generate_detections_v2(boxes, def _generate_detections_v2(boxes: tf.Tensor,
scores, scores: tf.Tensor,
pre_nms_top_k=5000, pre_nms_top_k: int = 5000,
pre_nms_score_threshold=0.05, pre_nms_score_threshold: float = 0.05,
nms_iou_threshold=0.5, nms_iou_threshold: float = 0.5,
max_num_detections=100): max_num_detections: int = 100):
"""Generates the final detections given the model outputs. """Generates the final detections given the model outputs.
This implementation unrolls classes dimension while using the tf.while_loop This implementation unrolls classes dimension while using the tf.while_loop
...@@ -337,11 +339,10 @@ def _generate_detections_v2(boxes, ...@@ -337,11 +339,10 @@ def _generate_detections_v2(boxes,
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
def _generate_detections_batched(boxes, def _generate_detections_batched(boxes: tf.Tensor, scores: tf.Tensor,
scores, pre_nms_score_threshold: float,
pre_nms_score_threshold, nms_iou_threshold: float,
nms_iou_threshold, max_num_detections: int):
max_num_detections):
"""Generates detected boxes with scores and classes for one-stage detector. """Generates detected boxes with scores and classes for one-stage detector.
The function takes output of multi-level ConvNets and anchor boxes and The function takes output of multi-level ConvNets and anchor boxes and
...@@ -393,12 +394,12 @@ class DetectionGenerator(tf.keras.layers.Layer): ...@@ -393,12 +394,12 @@ class DetectionGenerator(tf.keras.layers.Layer):
"""Generates the final detected boxes with scores and classes.""" """Generates the final detected boxes with scores and classes."""
def __init__(self, def __init__(self,
apply_nms=True, apply_nms: bool = True,
pre_nms_top_k=5000, pre_nms_top_k: int = 5000,
pre_nms_score_threshold=0.05, pre_nms_score_threshold: float = 0.05,
nms_iou_threshold=0.5, nms_iou_threshold: float = 0.5,
max_num_detections=100, max_num_detections: int = 100,
use_batched_nms=False, use_batched_nms: bool = False,
**kwargs): **kwargs):
"""Initializes a detection generator. """Initializes a detection generator.
...@@ -427,11 +428,8 @@ class DetectionGenerator(tf.keras.layers.Layer): ...@@ -427,11 +428,8 @@ class DetectionGenerator(tf.keras.layers.Layer):
} }
super(DetectionGenerator, self).__init__(**kwargs) super(DetectionGenerator, self).__init__(**kwargs)
def __call__(self, def __call__(self, raw_boxes: tf.Tensor, raw_scores: tf.Tensor,
raw_boxes, anchor_boxes: tf.Tensor, image_shape: tf.Tensor):
raw_scores,
anchor_boxes,
image_shape):
"""Generates final detections. """Generates final detections.
Args: Args:
...@@ -546,12 +544,12 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer): ...@@ -546,12 +544,12 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
"""Generates detected boxes with scores and classes for one-stage detector.""" """Generates detected boxes with scores and classes for one-stage detector."""
def __init__(self, def __init__(self,
apply_nms=True, apply_nms: bool = True,
pre_nms_top_k=5000, pre_nms_top_k: int = 5000,
pre_nms_score_threshold=0.05, pre_nms_score_threshold: float = 0.05,
nms_iou_threshold=0.5, nms_iou_threshold: float = 0.5,
max_num_detections=100, max_num_detections: int = 100,
use_batched_nms=False, use_batched_nms: bool = False,
**kwargs): **kwargs):
"""Initializes a multi-level detection generator. """Initializes a multi-level detection generator.
...@@ -581,11 +579,11 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer): ...@@ -581,11 +579,11 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
super(MultilevelDetectionGenerator, self).__init__(**kwargs) super(MultilevelDetectionGenerator, self).__init__(**kwargs)
def __call__(self, def __call__(self,
raw_boxes, raw_boxes: Mapping[str, tf.Tensor],
raw_scores, raw_scores: Mapping[str, tf.Tensor],
anchor_boxes, anchor_boxes: tf.Tensor,
image_shape, image_shape: tf.Tensor,
raw_attributes=None): raw_attributes: Mapping[str, tf.Tensor] = None):
"""Generates final detections. """Generates final detections.
Args: Args:
...@@ -600,11 +598,10 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer): ...@@ -600,11 +598,10 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
image_shape: A `tf.Tensor` of shape of [batch_size, 2] storing the image image_shape: A `tf.Tensor` of shape of [batch_size, 2] storing the image
height and width w.r.t. the scaled image, i.e. the same image space as height and width w.r.t. the scaled image, i.e. the same image space as
`box_outputs` and `anchor_boxes`. `box_outputs` and `anchor_boxes`.
raw_attributes: If not None, a `dict` of raw_attributes: If not None, a `dict` of (attribute_name,
(attribute_name, attribute_prediction) pairs. `attribute_prediction` attribute_prediction) pairs. `attribute_prediction` is a dict that
is a dict that contains keys representing FPN levels and values contains keys representing FPN levels and values representing tenors of
representing tenors of shape `[batch, feature_h, feature_w, shape `[batch, feature_h, feature_w, num_anchors * attribute_size]`.
num_anchors * attribute_size]`.
Returns: Returns:
If `apply_nms` = True, the return is a dictionary with keys: If `apply_nms` = True, the return is a dictionary with keys:
......
...@@ -20,13 +20,13 @@ import tensorflow as tf ...@@ -20,13 +20,13 @@ import tensorflow as tf
from official.vision.beta.ops import spatial_transform_ops from official.vision.beta.ops import spatial_transform_ops
def _sample_and_crop_foreground_masks(candidate_rois, def _sample_and_crop_foreground_masks(candidate_rois: tf.Tensor,
candidate_gt_boxes, candidate_gt_boxes: tf.Tensor,
candidate_gt_classes, candidate_gt_classes: tf.Tensor,
candidate_gt_indices, candidate_gt_indices: tf.Tensor,
gt_masks, gt_masks: tf.Tensor,
num_sampled_masks=128, num_sampled_masks: int = 128,
mask_target_size=28): mask_target_size: int = 28):
"""Samples and creates cropped foreground masks for training. """Samples and creates cropped foreground masks for training.
Args: Args:
...@@ -104,22 +104,16 @@ def _sample_and_crop_foreground_masks(candidate_rois, ...@@ -104,22 +104,16 @@ def _sample_and_crop_foreground_masks(candidate_rois,
class MaskSampler(tf.keras.layers.Layer): class MaskSampler(tf.keras.layers.Layer):
"""Samples and creates mask training targets.""" """Samples and creates mask training targets."""
def __init__(self, def __init__(self, mask_target_size: int, num_sampled_masks: int, **kwargs):
mask_target_size,
num_sampled_masks,
**kwargs):
self._config_dict = { self._config_dict = {
'mask_target_size': mask_target_size, 'mask_target_size': mask_target_size,
'num_sampled_masks': num_sampled_masks, 'num_sampled_masks': num_sampled_masks,
} }
super(MaskSampler, self).__init__(**kwargs) super(MaskSampler, self).__init__(**kwargs)
def call(self, def call(self, candidate_rois: tf.Tensor, candidate_gt_boxes: tf.Tensor,
candidate_rois, candidate_gt_classes: tf.Tensor, candidate_gt_indices: tf.Tensor,
candidate_gt_boxes, gt_masks: tf.Tensor):
candidate_gt_classes,
candidate_gt_indices,
gt_masks):
"""Samples and creates mask targets for training. """Samples and creates mask targets for training.
Args: Args:
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
"""Contains definitions of ROI aligner.""" """Contains definitions of ROI aligner."""
from typing import Mapping
import tensorflow as tf import tensorflow as tf
from official.vision.beta.ops import spatial_transform_ops from official.vision.beta.ops import spatial_transform_ops
...@@ -23,10 +24,7 @@ from official.vision.beta.ops import spatial_transform_ops ...@@ -23,10 +24,7 @@ from official.vision.beta.ops import spatial_transform_ops
class MultilevelROIAligner(tf.keras.layers.Layer): class MultilevelROIAligner(tf.keras.layers.Layer):
"""Performs ROIAlign for the second stage processing.""" """Performs ROIAlign for the second stage processing."""
def __init__(self, def __init__(self, crop_size: int = 7, sample_offset: float = 0.5, **kwargs):
crop_size=7,
sample_offset=0.5,
**kwargs):
"""Initializes a ROI aligner. """Initializes a ROI aligner.
Args: Args:
...@@ -40,7 +38,10 @@ class MultilevelROIAligner(tf.keras.layers.Layer): ...@@ -40,7 +38,10 @@ class MultilevelROIAligner(tf.keras.layers.Layer):
} }
super(MultilevelROIAligner, self).__init__(**kwargs) super(MultilevelROIAligner, self).__init__(**kwargs)
def call(self, features, boxes, training=None): def call(self,
features: Mapping[str, tf.Tensor],
boxes: tf.Tensor,
training: bool = None):
"""Generates ROIs. """Generates ROIs.
Args: Args:
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
"""Contains definitions of ROI generator.""" """Contains definitions of ROI generator."""
from typing import Optional, Mapping
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
...@@ -21,19 +21,19 @@ from official.vision.beta.ops import box_ops ...@@ -21,19 +21,19 @@ from official.vision.beta.ops import box_ops
from official.vision.beta.ops import nms from official.vision.beta.ops import nms
def _multilevel_propose_rois(raw_boxes, def _multilevel_propose_rois(raw_boxes: Mapping[str, tf.Tensor],
raw_scores, raw_scores: Mapping[str, tf.Tensor],
anchor_boxes, anchor_boxes: Mapping[str, tf.Tensor],
image_shape, image_shape: tf.Tensor,
pre_nms_top_k=2000, pre_nms_top_k: int = 2000,
pre_nms_score_threshold=0.0, pre_nms_score_threshold: float = 0.0,
pre_nms_min_size_threshold=0.0, pre_nms_min_size_threshold: float = 0.0,
nms_iou_threshold=0.7, nms_iou_threshold: float = 0.7,
num_proposals=1000, num_proposals: int = 1000,
use_batched_nms=False, use_batched_nms: bool = False,
decode_boxes=True, decode_boxes: bool = True,
clip_boxes=True, clip_boxes: bool = True,
apply_sigmoid_to_score=True): apply_sigmoid_to_score: bool = True):
"""Proposes RoIs given a group of candidates from different FPN levels. """Proposes RoIs given a group of candidates from different FPN levels.
The following describes the steps: The following describes the steps:
...@@ -181,17 +181,17 @@ class MultilevelROIGenerator(tf.keras.layers.Layer): ...@@ -181,17 +181,17 @@ class MultilevelROIGenerator(tf.keras.layers.Layer):
"""Proposes RoIs for the second stage processing.""" """Proposes RoIs for the second stage processing."""
def __init__(self, def __init__(self,
pre_nms_top_k=2000, pre_nms_top_k: int = 2000,
pre_nms_score_threshold=0.0, pre_nms_score_threshold: float = 0.0,
pre_nms_min_size_threshold=0.0, pre_nms_min_size_threshold: float = 0.0,
nms_iou_threshold=0.7, nms_iou_threshold: float = 0.7,
num_proposals=1000, num_proposals: int = 1000,
test_pre_nms_top_k=1000, test_pre_nms_top_k: int = 1000,
test_pre_nms_score_threshold=0.0, test_pre_nms_score_threshold: float = 0.0,
test_pre_nms_min_size_threshold=0.0, test_pre_nms_min_size_threshold: float = 0.0,
test_nms_iou_threshold=0.7, test_nms_iou_threshold: float = 0.7,
test_num_proposals=1000, test_num_proposals: int = 1000,
use_batched_nms=False, use_batched_nms: bool = False,
**kwargs): **kwargs):
"""Initializes a ROI generator. """Initializes a ROI generator.
...@@ -240,11 +240,11 @@ class MultilevelROIGenerator(tf.keras.layers.Layer): ...@@ -240,11 +240,11 @@ class MultilevelROIGenerator(tf.keras.layers.Layer):
super(MultilevelROIGenerator, self).__init__(**kwargs) super(MultilevelROIGenerator, self).__init__(**kwargs)
def call(self, def call(self,
raw_boxes, raw_boxes: Mapping[str, tf.Tensor],
raw_scores, raw_scores: Mapping[str, tf.Tensor],
anchor_boxes, anchor_boxes: Mapping[str, tf.Tensor],
image_shape, image_shape: tf.Tensor,
training=None): training: Optional[bool] = None):
"""Proposes RoIs given a group of candidates from different FPN levels. """Proposes RoIs given a group of candidates from different FPN levels.
The following describes the steps: The following describes the steps:
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
"""Contains definitions of ROI sampler.""" """Contains definitions of ROI sampler."""
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
...@@ -26,12 +25,12 @@ class ROISampler(tf.keras.layers.Layer): ...@@ -26,12 +25,12 @@ class ROISampler(tf.keras.layers.Layer):
"""Samples ROIs and assigns targets to the sampled ROIs.""" """Samples ROIs and assigns targets to the sampled ROIs."""
def __init__(self, def __init__(self,
mix_gt_boxes=True, mix_gt_boxes: bool = True,
num_sampled_rois=512, num_sampled_rois: int = 512,
foreground_fraction=0.25, foreground_fraction: float = 0.25,
foreground_iou_threshold=0.5, foreground_iou_threshold: float = 0.5,
background_iou_high_threshold=0.5, background_iou_high_threshold: float = 0.5,
background_iou_low_threshold=0, background_iou_low_threshold: float = 0,
**kwargs): **kwargs):
"""Initializes a ROI sampler. """Initializes a ROI sampler.
...@@ -73,7 +72,7 @@ class ROISampler(tf.keras.layers.Layer): ...@@ -73,7 +72,7 @@ class ROISampler(tf.keras.layers.Layer):
num_sampled_rois, foreground_fraction) num_sampled_rois, foreground_fraction)
super(ROISampler, self).__init__(**kwargs) super(ROISampler, self).__init__(**kwargs)
def call(self, boxes, gt_boxes, gt_classes): def call(self, boxes: tf.Tensor, gt_boxes: tf.Tensor, gt_classes: tf.Tensor):
"""Assigns the proposals with groundtruth classes and performs subsmpling. """Assigns the proposals with groundtruth classes and performs subsmpling.
Given `proposed_boxes`, `gt_boxes`, and `gt_classes`, the function uses the Given `proposed_boxes`, `gt_boxes`, and `gt_classes`, the function uses the
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Image classification task definition.""" """Image classification task definition."""
from typing import Any, Optional, List, Tuple
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
...@@ -51,7 +51,7 @@ class ImageClassificationTask(base_task.Task): ...@@ -51,7 +51,7 @@ class ImageClassificationTask(base_task.Task):
return model return model
def initialize(self, model: tf.keras.Model): def initialize(self, model: tf.keras.Model):
"""Loading pretrained checkpoint.""" """Loads pretrained checkpoint."""
if not self.task_config.init_checkpoint: if not self.task_config.init_checkpoint:
return return
...@@ -75,7 +75,9 @@ class ImageClassificationTask(base_task.Task): ...@@ -75,7 +75,9 @@ class ImageClassificationTask(base_task.Task):
logging.info('Finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
def build_inputs(self, params, input_context=None): def build_inputs(self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Builds classification input.""" """Builds classification input."""
num_classes = self.task_config.model.num_classes num_classes = self.task_config.model.num_classes
...@@ -112,13 +114,16 @@ class ImageClassificationTask(base_task.Task): ...@@ -112,13 +114,16 @@ class ImageClassificationTask(base_task.Task):
return dataset return dataset
def build_losses(self, labels, model_outputs, aux_losses=None): def build_losses(self,
"""Sparse categorical cross entropy loss. labels: tf.Tensor,
model_outputs: tf.Tensor,
aux_losses: Optional[Any] = None):
"""Builds sparse categorical cross entropy loss.
Args: Args:
labels: labels. labels: Input groundtruth labels.
model_outputs: Output logits of the classifier. model_outputs: Output logits of the classifier.
aux_losses: auxiliarly loss tensors, i.e. `losses` in keras.Model. aux_losses: The auxiliarly loss tensors, i.e. `losses` in tf.keras.Model.
Returns: Returns:
The total loss tensor. The total loss tensor.
...@@ -140,7 +145,7 @@ class ImageClassificationTask(base_task.Task): ...@@ -140,7 +145,7 @@ class ImageClassificationTask(base_task.Task):
return total_loss return total_loss
def build_metrics(self, training=True): def build_metrics(self, training: bool = True):
"""Gets streaming metrics for training/validation.""" """Gets streaming metrics for training/validation."""
k = self.task_config.evaluation.top_k k = self.task_config.evaluation.top_k
if self.task_config.losses.one_hot: if self.task_config.losses.one_hot:
...@@ -155,14 +160,18 @@ class ImageClassificationTask(base_task.Task): ...@@ -155,14 +160,18 @@ class ImageClassificationTask(base_task.Task):
k=k, name='top_{}_accuracy'.format(k))] k=k, name='top_{}_accuracy'.format(k))]
return metrics return metrics
def train_step(self, inputs, model, optimizer, metrics=None): def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward. """Does forward and backward.
Args: Args:
inputs: a dictionary of input tensors. inputs: A tuple of of input tensors of (features, labels).
model: the model, forward pass definition. model: A tf.keras.Model instance.
optimizer: the optimizer for this training step. optimizer: The optimizer for this training step.
metrics: a nested structure of metrics objects. metrics: A nested structure of metrics objects.
Returns: Returns:
A dictionary of logs. A dictionary of logs.
...@@ -209,13 +218,16 @@ class ImageClassificationTask(base_task.Task): ...@@ -209,13 +218,16 @@ class ImageClassificationTask(base_task.Task):
logs.update({m.name: m.result() for m in model.metrics}) logs.update({m.name: m.result() for m in model.metrics})
return logs return logs
def validation_step(self, inputs, model, metrics=None): def validation_step(self,
"""Validatation step. inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None):
"""Runs validatation step.
Args: Args:
inputs: a dictionary of input tensors. inputs: A tuple of of input tensors of (features, labels).
model: the keras.Model. model: A tf.keras.Model instance.
metrics: a nested structure of metrics objects. metrics: A nested structure of metrics objects.
Returns: Returns:
A dictionary of logs. A dictionary of logs.
...@@ -237,6 +249,6 @@ class ImageClassificationTask(base_task.Task): ...@@ -237,6 +249,6 @@ class ImageClassificationTask(base_task.Task):
logs.update({m.name: m.result() for m in model.metrics}) logs.update({m.name: m.result() for m in model.metrics})
return logs return logs
def inference_step(self, inputs, model): def inference_step(self, inputs: tf.Tensor, model: tf.keras.Model):
"""Performs the forward step.""" """Performs the forward step."""
return model(inputs, training=False) return model(inputs, training=False)
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""RetinaNet task definition.""" """RetinaNet task definition."""
from typing import Any, Optional, List, Tuple, Mapping
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
...@@ -30,7 +30,8 @@ from official.vision.beta.losses import maskrcnn_losses ...@@ -30,7 +30,8 @@ from official.vision.beta.losses import maskrcnn_losses
from official.vision.beta.modeling import factory from official.vision.beta.modeling import factory
def zero_out_disallowed_class_ids(batch_class_ids, allowed_class_ids): def zero_out_disallowed_class_ids(batch_class_ids: tf.Tensor,
allowed_class_ids: List[int]):
"""Zero out IDs of classes not in allowed_class_ids. """Zero out IDs of classes not in allowed_class_ids.
Args: Args:
...@@ -106,7 +107,9 @@ class MaskRCNNTask(base_task.Task): ...@@ -106,7 +107,9 @@ class MaskRCNNTask(base_task.Task):
logging.info('Finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
def build_inputs(self, params, input_context=None): def build_inputs(self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Build input dataset.""" """Build input dataset."""
decoder_cfg = params.decoder.get() decoder_cfg = params.decoder.get()
if params.decoder.type == 'simple_decoder': if params.decoder.type == 'simple_decoder':
...@@ -152,7 +155,10 @@ class MaskRCNNTask(base_task.Task): ...@@ -152,7 +155,10 @@ class MaskRCNNTask(base_task.Task):
return dataset return dataset
def build_losses(self, outputs, labels, aux_losses=None): def build_losses(self,
outputs: Mapping[str, Any],
labels: Mapping[str, Any],
aux_losses: Optional[Any] = None):
"""Build Mask R-CNN losses.""" """Build Mask R-CNN losses."""
params = self.task_config params = self.task_config
...@@ -218,7 +224,7 @@ class MaskRCNNTask(base_task.Task): ...@@ -218,7 +224,7 @@ class MaskRCNNTask(base_task.Task):
} }
return losses return losses
def build_metrics(self, training=True): def build_metrics(self, training: bool = True):
"""Build detection metrics.""" """Build detection metrics."""
metrics = [] metrics = []
if training: if training:
...@@ -242,7 +248,11 @@ class MaskRCNNTask(base_task.Task): ...@@ -242,7 +248,11 @@ class MaskRCNNTask(base_task.Task):
return metrics return metrics
def train_step(self, inputs, model, optimizer, metrics=None): def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward. """Does forward and backward.
Args: Args:
...@@ -294,7 +304,10 @@ class MaskRCNNTask(base_task.Task): ...@@ -294,7 +304,10 @@ class MaskRCNNTask(base_task.Task):
return logs return logs
def validation_step(self, inputs, model, metrics=None): def validation_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None):
"""Validatation step. """Validatation step.
Args: Args:
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""RetinaNet task definition.""" """RetinaNet task definition."""
from typing import Any, Optional, List, Tuple, Mapping
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
...@@ -84,7 +84,9 @@ class RetinaNetTask(base_task.Task): ...@@ -84,7 +84,9 @@ class RetinaNetTask(base_task.Task):
logging.info('Finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
def build_inputs(self, params, input_context=None): def build_inputs(self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Build input dataset.""" """Build input dataset."""
if params.tfds_name: if params.tfds_name:
...@@ -131,7 +133,10 @@ class RetinaNetTask(base_task.Task): ...@@ -131,7 +133,10 @@ class RetinaNetTask(base_task.Task):
return dataset return dataset
def build_losses(self, outputs, labels, aux_losses=None): def build_losses(self,
outputs: Mapping[str, Any],
labels: Mapping[str, Any],
aux_losses: Optional[Any] = None):
"""Build RetinaNet losses.""" """Build RetinaNet losses."""
params = self.task_config params = self.task_config
cls_loss_fn = keras_cv.losses.FocalLoss( cls_loss_fn = keras_cv.losses.FocalLoss(
...@@ -172,7 +177,7 @@ class RetinaNetTask(base_task.Task): ...@@ -172,7 +177,7 @@ class RetinaNetTask(base_task.Task):
return total_loss, cls_loss, box_loss, model_loss return total_loss, cls_loss, box_loss, model_loss
def build_metrics(self, training=True): def build_metrics(self, training: bool = True):
"""Build detection metrics.""" """Build detection metrics."""
metrics = [] metrics = []
metric_names = ['total_loss', 'cls_loss', 'box_loss', 'model_loss'] metric_names = ['total_loss', 'cls_loss', 'box_loss', 'model_loss']
...@@ -190,7 +195,11 @@ class RetinaNetTask(base_task.Task): ...@@ -190,7 +195,11 @@ class RetinaNetTask(base_task.Task):
return metrics return metrics
def train_step(self, inputs, model, optimizer, metrics=None): def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward. """Does forward and backward.
Args: Args:
...@@ -241,7 +250,10 @@ class RetinaNetTask(base_task.Task): ...@@ -241,7 +250,10 @@ class RetinaNetTask(base_task.Task):
return logs return logs
def validation_step(self, inputs, model, metrics=None): def validation_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None):
"""Validatation step. """Validatation step.
Args: Args:
......
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Image segmentation task definition.""" """Image segmentation task definition."""
from typing import Any, Optional, List, Tuple, Mapping, Union
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
...@@ -79,7 +79,9 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -79,7 +79,9 @@ class SemanticSegmentationTask(base_task.Task):
logging.info('Finished loading pretrained checkpoint from %s', logging.info('Finished loading pretrained checkpoint from %s',
ckpt_dir_or_file) ckpt_dir_or_file)
def build_inputs(self, params, input_context=None): def build_inputs(self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Builds classification input.""" """Builds classification input."""
ignore_label = self.task_config.losses.ignore_label ignore_label = self.task_config.losses.ignore_label
...@@ -114,7 +116,10 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -114,7 +116,10 @@ class SemanticSegmentationTask(base_task.Task):
return dataset return dataset
def build_losses(self, labels, model_outputs, aux_losses=None): def build_losses(self,
labels: Mapping[str, tf.Tensor],
model_outputs: Union[Mapping[str, tf.Tensor], tf.Tensor],
aux_losses: Optional[Any] = None):
"""Segmentation loss. """Segmentation loss.
Args: Args:
...@@ -140,7 +145,7 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -140,7 +145,7 @@ class SemanticSegmentationTask(base_task.Task):
return total_loss return total_loss
def build_metrics(self, training=True): def build_metrics(self, training: bool = True):
"""Gets streaming metrics for training/validation.""" """Gets streaming metrics for training/validation."""
metrics = [] metrics = []
if training and self.task_config.evaluation.report_train_mean_iou: if training and self.task_config.evaluation.report_train_mean_iou:
...@@ -159,7 +164,11 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -159,7 +164,11 @@ class SemanticSegmentationTask(base_task.Task):
return metrics return metrics
def train_step(self, inputs, model, optimizer, metrics=None): def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward. """Does forward and backward.
Args: Args:
...@@ -214,7 +223,10 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -214,7 +223,10 @@ class SemanticSegmentationTask(base_task.Task):
return logs return logs
def validation_step(self, inputs, model, metrics=None): def validation_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None):
"""Validatation step. """Validatation step.
Args: Args:
...@@ -251,7 +263,7 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -251,7 +263,7 @@ class SemanticSegmentationTask(base_task.Task):
return logs return logs
def inference_step(self, inputs, model): def inference_step(self, inputs: tf.Tensor, model: tf.keras.Model):
"""Performs the forward step.""" """Performs the forward step."""
return model(inputs, training=False) return model(inputs, training=False)
......
...@@ -12,8 +12,9 @@ ...@@ -12,8 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Video classification task definition.""" """Video classification task definition."""
from typing import Any, Optional, List, Tuple
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
from official.core import base_task from official.core import base_task
...@@ -68,7 +69,9 @@ class VideoClassificationTask(base_task.Task): ...@@ -68,7 +69,9 @@ class VideoClassificationTask(base_task.Task):
tf.io.VarLenFeature(dtype=tf.float32)) tf.io.VarLenFeature(dtype=tf.float32))
return decoder.decode return decoder.decode
def build_inputs(self, params: exp_cfg.DataConfig, input_context=None): def build_inputs(self,
params: exp_cfg.DataConfig,
input_context: Optional[tf.distribute.InputContext] = None):
"""Builds classification input.""" """Builds classification input."""
parser = video_input.Parser(input_params=params) parser = video_input.Parser(input_params=params)
...@@ -85,7 +88,10 @@ class VideoClassificationTask(base_task.Task): ...@@ -85,7 +88,10 @@ class VideoClassificationTask(base_task.Task):
return dataset return dataset
def build_losses(self, labels, model_outputs, aux_losses=None): def build_losses(self,
labels: Any,
model_outputs: Any,
aux_losses: Optional[Any] = None):
"""Sparse categorical cross entropy loss. """Sparse categorical cross entropy loss.
Args: Args:
...@@ -132,7 +138,7 @@ class VideoClassificationTask(base_task.Task): ...@@ -132,7 +138,7 @@ class VideoClassificationTask(base_task.Task):
return all_losses return all_losses
def build_metrics(self, training=True): def build_metrics(self, training: bool = True):
"""Gets streaming metrics for training/validation.""" """Gets streaming metrics for training/validation."""
if self.task_config.losses.one_hot: if self.task_config.losses.one_hot:
metrics = [ metrics = [
...@@ -168,7 +174,8 @@ class VideoClassificationTask(base_task.Task): ...@@ -168,7 +174,8 @@ class VideoClassificationTask(base_task.Task):
] ]
return metrics return metrics
def process_metrics(self, metrics, labels, model_outputs): def process_metrics(self, metrics: List[Any], labels: Any,
model_outputs: Any):
"""Process and update metrics. """Process and update metrics.
Called when using custom training loop API. Called when using custom training loop API.
...@@ -183,7 +190,11 @@ class VideoClassificationTask(base_task.Task): ...@@ -183,7 +190,11 @@ class VideoClassificationTask(base_task.Task):
for metric in metrics: for metric in metrics:
metric.update_state(labels, model_outputs) metric.update_state(labels, model_outputs)
def train_step(self, inputs, model, optimizer, metrics=None): def train_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
optimizer: tf.keras.optimizers.Optimizer,
metrics: Optional[List[Any]] = None):
"""Does forward and backward. """Does forward and backward.
Args: Args:
...@@ -240,7 +251,10 @@ class VideoClassificationTask(base_task.Task): ...@@ -240,7 +251,10 @@ class VideoClassificationTask(base_task.Task):
logs.update({m.name: m.result() for m in model.metrics}) logs.update({m.name: m.result() for m in model.metrics})
return logs return logs
def validation_step(self, inputs, model, metrics=None): def validation_step(self,
inputs: Tuple[Any, Any],
model: tf.keras.Model,
metrics: Optional[List[Any]] = None):
"""Validatation step. """Validatation step.
Args: Args:
...@@ -266,7 +280,7 @@ class VideoClassificationTask(base_task.Task): ...@@ -266,7 +280,7 @@ class VideoClassificationTask(base_task.Task):
logs.update({m.name: m.result() for m in model.metrics}) logs.update({m.name: m.result() for m in model.metrics})
return logs return logs
def inference_step(self, features, model): def inference_step(self, features: tf.Tensor, model: tf.keras.Model):
"""Performs the forward step.""" """Performs the forward step."""
outputs = model(features, training=False) outputs = model(features, training=False)
if self.task_config.train_data.is_multilabel: if self.task_config.train_data.is_multilabel:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment