ModelZoo / ResNet50_tensorflow / Commits / c2e19c97

Commit c2e19c97, authored Apr 16, 2021 by Fan Yang, committed by A. Unique TensorFlower on Apr 16, 2021.

Internal change.

PiperOrigin-RevId: 368935233
Parent: 127c9d80
Showing 14 changed files with 314 additions and 249 deletions (+314, -249).
Changed files:

official/vision/beta/modeling/heads/dense_prediction_heads.py    +38  -33
official/vision/beta/modeling/heads/instance_heads.py            +35  -32
official/vision/beta/modeling/heads/segmentation_heads.py        +22  -20
official/vision/beta/modeling/layers/box_sampler.py               +4   -3
official/vision/beta/modeling/layers/detection_generator.py      +51  -54
official/vision/beta/modeling/layers/mask_sampler.py             +11  -17
official/vision/beta/modeling/layers/roi_aligner.py               +6   -5
official/vision/beta/modeling/layers/roi_generator.py            +30  -30
official/vision/beta/modeling/layers/roi_sampler.py               +7   -8
official/vision/beta/tasks/image_classification.py               +31  -19
official/vision/beta/tasks/maskrcnn.py                            +20   -7
official/vision/beta/tasks/retinanet.py                           +18   -6
official/vision/beta/tasks/semantic_segmentation.py              +19   -7
official/vision/beta/tasks/video_classification.py               +22   -8
official/vision/beta/modeling/heads/dense_prediction_heads.py (view file @ c2e19c97)

@@ -14,7 +14,10 @@
 """Contains definitions of dense prediction heads."""
+from typing import List, Mapping, Optional, Tuple, Union
 # Import libraries
 import numpy as np
 import tensorflow as tf

@@ -25,22 +28,23 @@ from official.modeling import tf_utils
 class RetinaNetHead(tf.keras.layers.Layer):
   """Creates a RetinaNet head."""

-  def __init__(self,
-               min_level,
-               max_level,
-               num_classes,
-               num_anchors_per_location,
-               num_convs=4,
-               num_filters=256,
-               attribute_heads=None,
-               use_separable_conv=False,
-               activation='relu',
-               use_sync_bn=False,
-               norm_momentum=0.99,
-               norm_epsilon=0.001,
-               kernel_regularizer=None,
-               bias_regularizer=None,
-               **kwargs):
+  def __init__(self,
+               min_level: int,
+               max_level: int,
+               num_classes: int,
+               num_anchors_per_location: int,
+               num_convs: int = 4,
+               num_filters: int = 256,
+               attribute_heads: Mapping[str, Tuple[str, int]] = None,
+               use_separable_conv: bool = False,
+               activation: str = 'relu',
+               use_sync_bn: bool = False,
+               norm_momentum: float = 0.99,
+               norm_epsilon: float = 0.001,
+               kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+               bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+               **kwargs):
     """Initializes a RetinaNet head.

     Args:

@@ -93,7 +97,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
       self._bn_axis = 1
     self._activation = tf_utils.get_activation(activation)

-  def build(self, input_shape):
+  def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
     """Creates the variables of the head."""
     conv_op = (tf.keras.layers.SeparableConv2D
                if self._config_dict['use_separable_conv']

@@ -239,7 +243,7 @@ class RetinaNetHead(tf.keras.layers.Layer):
     super(RetinaNetHead, self).build(input_shape)

-  def call(self, features):
+  def call(self, features: Mapping[str, tf.Tensor]):
     """Forward pass of the RetinaNet head.

     Args:

@@ -325,20 +329,21 @@ class RetinaNetHead(tf.keras.layers.Layer):
 class RPNHead(tf.keras.layers.Layer):
   """Creates a Region Proposal Network (RPN) head."""

-  def __init__(self,
-               min_level,
-               max_level,
-               num_anchors_per_location,
-               num_convs=1,
-               num_filters=256,
-               use_separable_conv=False,
-               activation='relu',
-               use_sync_bn=False,
-               norm_momentum=0.99,
-               norm_epsilon=0.001,
-               kernel_regularizer=None,
-               bias_regularizer=None,
-               **kwargs):
+  def __init__(self,
+               min_level: int,
+               max_level: int,
+               num_anchors_per_location: int,
+               num_convs: int = 1,
+               num_filters: int = 256,
+               use_separable_conv: bool = False,
+               activation: str = 'relu',
+               use_sync_bn: bool = False,
+               norm_momentum: float = 0.99,
+               norm_epsilon: float = 0.001,
+               kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+               bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+               **kwargs):
     """Initializes a Region Proposal Network head.

     Args:

@@ -457,7 +462,7 @@ class RPNHead(tf.keras.layers.Layer):
     super(RPNHead, self).build(input_shape)

-  def call(self, features):
+  def call(self, features: Mapping[str, tf.Tensor]):
     """Forward pass of the RPN head.

     Args:
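The change in this file is a pure type-annotation pass over the public signatures; bodies are untouched. As a quick reference for the annotated RetinaNetHead constructor above, a minimal construction-only sketch (it assumes the Model Garden package from the file path is importable; the concrete values for min_level, max_level, num_classes, and num_anchors_per_location are illustrative assumptions, not part of the commit):

import tensorflow as tf
from official.vision.beta.modeling.heads import dense_prediction_heads

# Illustrative values only; parameter names and annotations come from the diff above.
retinanet_head = dense_prediction_heads.RetinaNetHead(
    min_level=3,
    max_level=7,
    num_classes=80,
    num_anchors_per_location=9,
    num_convs=4,
    num_filters=256,
    kernel_regularizer=tf.keras.regularizers.l2(1e-4),
)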
official/vision/beta/modeling/heads/instance_heads.py (view file @ c2e19c97)

@@ -14,6 +14,7 @@
 """Contains definitions of instance prediction heads."""
+from typing import List, Union, Optional
 # Import libraries
 import tensorflow as tf

@@ -24,20 +25,21 @@ from official.modeling import tf_utils
 class DetectionHead(tf.keras.layers.Layer):
   """Creates a detection head."""

-  def __init__(self,
-               num_classes,
-               num_convs=0,
-               num_filters=256,
-               use_separable_conv=False,
-               num_fcs=2,
-               fc_dims=1024,
-               activation='relu',
-               use_sync_bn=False,
-               norm_momentum=0.99,
-               norm_epsilon=0.001,
-               kernel_regularizer=None,
-               bias_regularizer=None,
-               **kwargs):
+  def __init__(self,
+               num_classes: int,
+               num_convs: int = 0,
+               num_filters: int = 256,
+               use_separable_conv: bool = False,
+               num_fcs: int = 2,
+               fc_dims: int = 1024,
+               activation: str = 'relu',
+               use_sync_bn: bool = False,
+               norm_momentum: float = 0.99,
+               norm_epsilon: float = 0.001,
+               kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+               bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+               **kwargs):
     """Initializes a detection head.

     Args:

@@ -85,7 +87,7 @@ class DetectionHead(tf.keras.layers.Layer):
       self._bn_axis = 1
     self._activation = tf_utils.get_activation(activation)

-  def build(self, input_shape):
+  def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
     """Creates the variables of the head."""
     conv_op = (tf.keras.layers.SeparableConv2D
                if self._config_dict['use_separable_conv']

@@ -163,7 +165,7 @@ class DetectionHead(tf.keras.layers.Layer):
     super(DetectionHead, self).build(input_shape)

-  def call(self, inputs, training=None):
+  def call(self, inputs: tf.Tensor, training: bool = None):
     """Forward pass of box and class branches for the Mask-RCNN model.

     Args:

@@ -211,20 +213,21 @@ class DetectionHead(tf.keras.layers.Layer):
 class MaskHead(tf.keras.layers.Layer):
   """Creates a mask head."""

-  def __init__(self,
-               num_classes,
-               upsample_factor=2,
-               num_convs=4,
-               num_filters=256,
-               use_separable_conv=False,
-               activation='relu',
-               use_sync_bn=False,
-               norm_momentum=0.99,
-               norm_epsilon=0.001,
-               kernel_regularizer=None,
-               bias_regularizer=None,
-               class_agnostic=False,
-               **kwargs):
+  def __init__(self,
+               num_classes: int,
+               upsample_factor: int = 2,
+               num_convs: int = 4,
+               num_filters: int = 256,
+               use_separable_conv: bool = False,
+               activation: str = 'relu',
+               use_sync_bn: bool = False,
+               norm_momentum: float = 0.99,
+               norm_epsilon: float = 0.001,
+               kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+               bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+               class_agnostic: bool = False,
+               **kwargs):
     """Initializes a mask head.

     Args:

@@ -272,7 +275,7 @@ class MaskHead(tf.keras.layers.Layer):
       self._bn_axis = 1
     self._activation = tf_utils.get_activation(activation)

-  def build(self, input_shape):
+  def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
     """Creates the variables of the head."""
     conv_op = (tf.keras.layers.SeparableConv2D
                if self._config_dict['use_separable_conv']

@@ -364,7 +367,7 @@ class MaskHead(tf.keras.layers.Layer):
     super(MaskHead, self).build(input_shape)

-  def call(self, inputs, training=None):
+  def call(self, inputs: List[tf.Tensor], training: bool = None):
     """Forward pass of mask branch for the Mask-RCNN model.

     Args:
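A comparable construction-only sketch for the annotated MaskHead above; only the parameter names, defaults, and annotations are taken from the diff, while the class count is an assumed value for illustration:

from official.vision.beta.modeling.heads import instance_heads

mask_head = instance_heads.MaskHead(
    num_classes=80,      # assumed value for illustration
    upsample_factor=2,   # defaults shown in the annotated signature
    num_convs=4,
    num_filters=256,
    class_agnostic=False,
)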
official/vision/beta/modeling/heads/segmentation_heads.py (view file @ c2e19c97)

@@ -13,7 +13,7 @@
 # limitations under the License.
 """Contains definitions of segmentation heads."""
 from typing import List, Union, Optional, Mapping
 import tensorflow as tf
 from official.modeling import tf_utils

@@ -25,23 +25,24 @@ from official.vision.beta.ops import spatial_transform_ops
 class SegmentationHead(tf.keras.layers.Layer):
   """Creates a segmentation head."""

-  def __init__(self,
-               num_classes,
-               level,
-               num_convs=2,
-               num_filters=256,
-               prediction_kernel_size=1,
-               upsample_factor=1,
-               feature_fusion=None,
-               low_level=2,
-               low_level_num_filters=48,
-               activation='relu',
-               use_sync_bn=False,
-               norm_momentum=0.99,
-               norm_epsilon=0.001,
-               kernel_regularizer=None,
-               bias_regularizer=None,
-               **kwargs):
+  def __init__(self,
+               num_classes: int,
+               level: Union[int, str],
+               num_convs: int = 2,
+               num_filters: int = 256,
+               prediction_kernel_size: int = 1,
+               upsample_factor: int = 1,
+               feature_fusion: Optional[str] = None,
+               low_level: int = 2,
+               low_level_num_filters: int = 48,
+               activation: str = 'relu',
+               use_sync_bn: bool = False,
+               norm_momentum: float = 0.99,
+               norm_epsilon: float = 0.001,
+               kernel_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+               bias_regularizer: Optional[tf.keras.regularizers.Regularizer] = None,
+               **kwargs):
     """Initializes a segmentation head.

     Args:

@@ -101,7 +102,7 @@ class SegmentationHead(tf.keras.layers.Layer):
       self._bn_axis = 1
     self._activation = tf_utils.get_activation(activation)

-  def build(self, input_shape):
+  def build(self, input_shape: Union[tf.TensorShape, List[tf.TensorShape]]):
     """Creates the variables of the segmentation head."""
     conv_op = tf.keras.layers.Conv2D
     conv_kwargs = {

@@ -159,7 +160,8 @@ class SegmentationHead(tf.keras.layers.Layer):
     super(SegmentationHead, self).build(input_shape)

-  def call(self, backbone_output, decoder_output):
+  def call(self, backbone_output: Mapping[str, tf.Tensor],
+           decoder_output: Mapping[str, tf.Tensor]):
     """Forward pass of the segmentation head.

     Args:
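The new level: Union[int, str] annotation documents that the head accepts either an integer pyramid level or its string key. A minimal construction sketch under that assumption (num_classes and level values are illustrative, not from the commit):

from official.vision.beta.modeling.heads import segmentation_heads

seg_head = segmentation_heads.SegmentationHead(
    num_classes=21,  # assumed, e.g. a VOC-style label count
    level=3,         # an int level; the annotation also allows the string '3'
    num_convs=2,
    upsample_factor=1,
)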
official/vision/beta/modeling/layers/box_sampler.py (view file @ c2e19c97)

@@ -25,8 +25,8 @@ class BoxSampler(tf.keras.layers.Layer):
   """Creates a BoxSampler to sample positive and negative boxes."""

   def __init__(self,
-               num_samples=512,
-               foreground_fraction=0.25,
+               num_samples: int = 512,
+               foreground_fraction: float = 0.25,
                **kwargs):
     """Initializes a box sampler.

@@ -42,7 +42,8 @@ class BoxSampler(tf.keras.layers.Layer):
     }
     super(BoxSampler, self).__init__(**kwargs)

-  def call(self, positive_matches, negative_matches, ignored_matches):
+  def call(self, positive_matches: tf.Tensor, negative_matches: tf.Tensor,
+           ignored_matches: tf.Tensor):
     """Samples and selects positive and negative instances.

     Args:
official/vision/beta/modeling/layers/detection_generator.py (view file @ c2e19c97)

@@ -13,7 +13,7 @@
 # limitations under the License.
 """Contains definitions of generators to generate the final detections."""
 from typing import Optional, Mapping
 # Import libraries
 import tensorflow as tf

@@ -21,13 +21,14 @@ from official.vision.beta.ops import box_ops
 from official.vision.beta.ops import nms

-def _generate_detections_v1(boxes,
-                            scores,
-                            attributes=None,
-                            pre_nms_top_k=5000,
-                            pre_nms_score_threshold=0.05,
-                            nms_iou_threshold=0.5,
-                            max_num_detections=100):
+def _generate_detections_v1(boxes: tf.Tensor,
+                            scores: tf.Tensor,
+                            attributes: Optional[Mapping[str, tf.Tensor]] = None,
+                            pre_nms_top_k: int = 5000,
+                            pre_nms_score_threshold: float = 0.05,
+                            nms_iou_threshold: float = 0.5,
+                            max_num_detections: int = 100):
   """Generates the final detections given the model outputs.

   The implementation unrolls the batch dimension and process images one by one.

@@ -117,13 +118,14 @@ def _generate_detections_v1(boxes,
   return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, nmsed_attributes

-def _generate_detections_per_image(boxes,
-                                   scores,
-                                   attributes=None,
-                                   pre_nms_top_k=5000,
-                                   pre_nms_score_threshold=0.05,
-                                   nms_iou_threshold=0.5,
-                                   max_num_detections=100):
+def _generate_detections_per_image(boxes: tf.Tensor,
+                                   scores: tf.Tensor,
+                                   attributes: Optional[Mapping[str, tf.Tensor]] = None,
+                                   pre_nms_top_k: int = 5000,
+                                   pre_nms_score_threshold: float = 0.05,
+                                   nms_iou_threshold: float = 0.5,
+                                   max_num_detections: int = 100):
   """Generates the final detections per image given the model outputs.

   Args:

@@ -225,7 +227,7 @@ def _generate_detections_per_image(boxes,
   return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections, nmsed_attributes

-def _select_top_k_scores(scores_in, pre_nms_num_detections):
+def _select_top_k_scores(scores_in: tf.Tensor, pre_nms_num_detections: int):
   """Selects top_k scores and indices for each class.

   Args:

@@ -255,12 +257,12 @@ def _select_top_k_scores(scores_in, pre_nms_num_detections):
                       [0, 2, 1]), tf.transpose(top_k_indices, [0, 2, 1])

-def _generate_detections_v2(boxes,
-                            scores,
-                            pre_nms_top_k=5000,
-                            pre_nms_score_threshold=0.05,
-                            nms_iou_threshold=0.5,
-                            max_num_detections=100):
+def _generate_detections_v2(boxes: tf.Tensor,
+                            scores: tf.Tensor,
+                            pre_nms_top_k: int = 5000,
+                            pre_nms_score_threshold: float = 0.05,
+                            nms_iou_threshold: float = 0.5,
+                            max_num_detections: int = 100):
   """Generates the final detections given the model outputs.

   This implementation unrolls classes dimension while using the tf.while_loop

@@ -337,11 +339,10 @@ def _generate_detections_v2(boxes,
   return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections

-def _generate_detections_batched(boxes,
-                                 scores,
-                                 pre_nms_score_threshold,
-                                 nms_iou_threshold,
-                                 max_num_detections):
+def _generate_detections_batched(boxes: tf.Tensor, scores: tf.Tensor,
+                                 pre_nms_score_threshold: float,
+                                 nms_iou_threshold: float,
+                                 max_num_detections: int):
   """Generates detected boxes with scores and classes for one-stage detector.

   The function takes output of multi-level ConvNets and anchor boxes and

@@ -393,12 +394,12 @@ class DetectionGenerator(tf.keras.layers.Layer):
   """Generates the final detected boxes with scores and classes."""

   def __init__(self,
-               apply_nms=True,
-               pre_nms_top_k=5000,
-               pre_nms_score_threshold=0.05,
-               nms_iou_threshold=0.5,
-               max_num_detections=100,
-               use_batched_nms=False,
+               apply_nms: bool = True,
+               pre_nms_top_k: int = 5000,
+               pre_nms_score_threshold: float = 0.05,
+               nms_iou_threshold: float = 0.5,
+               max_num_detections: int = 100,
+               use_batched_nms: bool = False,
                **kwargs):
     """Initializes a detection generator.

@@ -427,11 +428,8 @@ class DetectionGenerator(tf.keras.layers.Layer):
     }
     super(DetectionGenerator, self).__init__(**kwargs)

-  def __call__(self,
-               raw_boxes,
-               raw_scores,
-               anchor_boxes,
-               image_shape):
+  def __call__(self, raw_boxes: tf.Tensor, raw_scores: tf.Tensor,
+               anchor_boxes: tf.Tensor, image_shape: tf.Tensor):
     """Generates final detections.

     Args:

@@ -546,12 +544,12 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
   """Generates detected boxes with scores and classes for one-stage detector."""

   def __init__(self,
-               apply_nms=True,
-               pre_nms_top_k=5000,
-               pre_nms_score_threshold=0.05,
-               nms_iou_threshold=0.5,
-               max_num_detections=100,
-               use_batched_nms=False,
+               apply_nms: bool = True,
+               pre_nms_top_k: int = 5000,
+               pre_nms_score_threshold: float = 0.05,
+               nms_iou_threshold: float = 0.5,
+               max_num_detections: int = 100,
+               use_batched_nms: bool = False,
                **kwargs):
     """Initializes a multi-level detection generator.

@@ -581,11 +579,11 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
     super(MultilevelDetectionGenerator, self).__init__(**kwargs)

   def __call__(self,
-               raw_boxes,
-               raw_scores,
-               anchor_boxes,
-               image_shape,
-               raw_attributes=None):
+               raw_boxes: Mapping[str, tf.Tensor],
+               raw_scores: Mapping[str, tf.Tensor],
+               anchor_boxes: tf.Tensor,
+               image_shape: tf.Tensor,
+               raw_attributes: Mapping[str, tf.Tensor] = None):
     """Generates final detections.

     Args:

@@ -600,11 +598,10 @@ class MultilevelDetectionGenerator(tf.keras.layers.Layer):
       image_shape: A `tf.Tensor` of shape of [batch_size, 2] storing the image
        height and width w.r.t. the scaled image, i.e. the same image space as
        `box_outputs` and `anchor_boxes`.
-      raw_attributes: If not None, a `dict` of
-        (attribute_name, attribute_prediction) pairs. `attribute_prediction`
-        is a dict that contains keys representing FPN levels and values
-        representing tenors of shape `[batch, feature_h, feature_w,
-        num_anchors * attribute_size]`.
+      raw_attributes: If not None, a `dict` of (attribute_name,
+        attribute_prediction) pairs. `attribute_prediction` is a dict that
+        contains keys representing FPN levels and values representing tenors of
+        shape `[batch, feature_h, feature_w, num_anchors * attribute_size]`.

     Returns:
       If `apply_nms` = True, the return is a dictionary with keys:
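The two generator layers keep their defaults; only annotations are added. A construction-only sketch of DetectionGenerator with the defaults shown above (no tensors are run, the layer is merely instantiated; the __call__ inputs are now documented as tf.Tensor):

from official.vision.beta.modeling.layers import detection_generator

generator = detection_generator.DetectionGenerator(
    apply_nms=True,
    pre_nms_top_k=5000,
    pre_nms_score_threshold=0.05,
    nms_iou_threshold=0.5,
    max_num_detections=100,
    use_batched_nms=False,
)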
official/vision/beta/modeling/layers/mask_sampler.py (view file @ c2e19c97)

@@ -20,13 +20,13 @@ import tensorflow as tf
 from official.vision.beta.ops import spatial_transform_ops

-def _sample_and_crop_foreground_masks(candidate_rois,
-                                      candidate_gt_boxes,
-                                      candidate_gt_classes,
-                                      candidate_gt_indices,
-                                      gt_masks,
-                                      num_sampled_masks=128,
-                                      mask_target_size=28):
+def _sample_and_crop_foreground_masks(candidate_rois: tf.Tensor,
+                                      candidate_gt_boxes: tf.Tensor,
+                                      candidate_gt_classes: tf.Tensor,
+                                      candidate_gt_indices: tf.Tensor,
+                                      gt_masks: tf.Tensor,
+                                      num_sampled_masks: int = 128,
+                                      mask_target_size: int = 28):
   """Samples and creates cropped foreground masks for training.

   Args:

@@ -104,22 +104,16 @@ def _sample_and_crop_foreground_masks(candidate_rois,
 class MaskSampler(tf.keras.layers.Layer):
   """Samples and creates mask training targets."""

-  def __init__(self,
-               mask_target_size,
-               num_sampled_masks,
-               **kwargs):
+  def __init__(self, mask_target_size: int, num_sampled_masks: int, **kwargs):
     self._config_dict = {
         'mask_target_size': mask_target_size,
         'num_sampled_masks': num_sampled_masks,
     }
     super(MaskSampler, self).__init__(**kwargs)

-  def call(self,
-           candidate_rois,
-           candidate_gt_boxes,
-           candidate_gt_classes,
-           candidate_gt_indices,
-           gt_masks):
+  def call(self, candidate_rois: tf.Tensor, candidate_gt_boxes: tf.Tensor,
+           candidate_gt_classes: tf.Tensor, candidate_gt_indices: tf.Tensor,
+           gt_masks: tf.Tensor):
     """Samples and creates mask targets for training.

     Args:
official/vision/beta/modeling/layers/roi_aligner.py (view file @ c2e19c97)

@@ -14,6 +14,7 @@
 """Contains definitions of ROI aligner."""
+from typing import Mapping
 import tensorflow as tf
 from official.vision.beta.ops import spatial_transform_ops

@@ -23,10 +24,7 @@ from official.vision.beta.ops import spatial_transform_ops
 class MultilevelROIAligner(tf.keras.layers.Layer):
   """Performs ROIAlign for the second stage processing."""

-  def __init__(self,
-               crop_size=7,
-               sample_offset=0.5,
-               **kwargs):
+  def __init__(self, crop_size: int = 7, sample_offset: float = 0.5, **kwargs):
     """Initializes a ROI aligner.

     Args:

@@ -40,7 +38,10 @@ class MultilevelROIAligner(tf.keras.layers.Layer):
     }
     super(MultilevelROIAligner, self).__init__(**kwargs)

-  def call(self, features, boxes, training=None):
+  def call(self,
+           features: Mapping[str, tf.Tensor],
+           boxes: tf.Tensor,
+           training: bool = None):
     """Generates ROIs.

     Args:
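A construction-only sketch of the aligner; the call is omitted because it needs real multi-level feature maps, and the reading of features as a per-level dictionary is inferred from the new Mapping[str, tf.Tensor] annotation rather than stated in the commit:

from official.vision.beta.modeling.layers import roi_aligner

aligner = roi_aligner.MultilevelROIAligner(crop_size=7, sample_offset=0.5)
# At call time: aligner(features, boxes, training=False), where `features` is
# assumed to be a dict of per-level tensors and `boxes` holds the proposed RoIs.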
official/vision/beta/modeling/layers/roi_generator.py (view file @ c2e19c97)

@@ -13,7 +13,7 @@
 # limitations under the License.
 """Contains definitions of ROI generator."""
 from typing import Optional, Mapping
 # Import libraries
 import tensorflow as tf

@@ -21,19 +21,19 @@ from official.vision.beta.ops import box_ops
 from official.vision.beta.ops import nms

-def _multilevel_propose_rois(raw_boxes,
-                             raw_scores,
-                             anchor_boxes,
-                             image_shape,
-                             pre_nms_top_k=2000,
-                             pre_nms_score_threshold=0.0,
-                             pre_nms_min_size_threshold=0.0,
-                             nms_iou_threshold=0.7,
-                             num_proposals=1000,
-                             use_batched_nms=False,
-                             decode_boxes=True,
-                             clip_boxes=True,
-                             apply_sigmoid_to_score=True):
+def _multilevel_propose_rois(raw_boxes: Mapping[str, tf.Tensor],
+                             raw_scores: Mapping[str, tf.Tensor],
+                             anchor_boxes: Mapping[str, tf.Tensor],
+                             image_shape: tf.Tensor,
+                             pre_nms_top_k: int = 2000,
+                             pre_nms_score_threshold: float = 0.0,
+                             pre_nms_min_size_threshold: float = 0.0,
+                             nms_iou_threshold: float = 0.7,
+                             num_proposals: int = 1000,
+                             use_batched_nms: bool = False,
+                             decode_boxes: bool = True,
+                             clip_boxes: bool = True,
+                             apply_sigmoid_to_score: bool = True):
   """Proposes RoIs given a group of candidates from different FPN levels.

   The following describes the steps:

@@ -181,17 +181,17 @@ class MultilevelROIGenerator(tf.keras.layers.Layer):
   """Proposes RoIs for the second stage processing."""

   def __init__(self,
-               pre_nms_top_k=2000,
-               pre_nms_score_threshold=0.0,
-               pre_nms_min_size_threshold=0.0,
-               nms_iou_threshold=0.7,
-               num_proposals=1000,
-               test_pre_nms_top_k=1000,
-               test_pre_nms_score_threshold=0.0,
-               test_pre_nms_min_size_threshold=0.0,
-               test_nms_iou_threshold=0.7,
-               test_num_proposals=1000,
-               use_batched_nms=False,
+               pre_nms_top_k: int = 2000,
+               pre_nms_score_threshold: float = 0.0,
+               pre_nms_min_size_threshold: float = 0.0,
+               nms_iou_threshold: float = 0.7,
+               num_proposals: int = 1000,
+               test_pre_nms_top_k: int = 1000,
+               test_pre_nms_score_threshold: float = 0.0,
+               test_pre_nms_min_size_threshold: float = 0.0,
+               test_nms_iou_threshold: float = 0.7,
+               test_num_proposals: int = 1000,
+               use_batched_nms: bool = False,
                **kwargs):
     """Initializes a ROI generator.

@@ -240,11 +240,11 @@ class MultilevelROIGenerator(tf.keras.layers.Layer):
     super(MultilevelROIGenerator, self).__init__(**kwargs)

   def call(self,
-           raw_boxes,
-           raw_scores,
-           anchor_boxes,
-           image_shape,
-           training=None):
+           raw_boxes: Mapping[str, tf.Tensor],
+           raw_scores: Mapping[str, tf.Tensor],
+           anchor_boxes: Mapping[str, tf.Tensor],
+           image_shape: tf.Tensor,
+           training: Optional[bool] = None):
     """Proposes RoIs given a group of candidates from different FPN levels.

     The following describes the steps:
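The generator now spells out separate training and test NMS settings with explicit types. A construction-only sketch using the defaults listed above (the variable name rpn_roi_generator is purely illustrative):

from official.vision.beta.modeling.layers import roi_generator

rpn_roi_generator = roi_generator.MultilevelROIGenerator(
    pre_nms_top_k=2000,
    nms_iou_threshold=0.7,
    num_proposals=1000,
    test_pre_nms_top_k=1000,
    test_num_proposals=1000,
)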
official/vision/beta/modeling/layers/roi_sampler.py (view file @ c2e19c97)

@@ -13,7 +13,6 @@
 # limitations under the License.
 """Contains definitions of ROI sampler."""
-# Import libraries
 import tensorflow as tf

@@ -26,12 +25,12 @@ class ROISampler(tf.keras.layers.Layer):
   """Samples ROIs and assigns targets to the sampled ROIs."""

   def __init__(self,
-               mix_gt_boxes=True,
-               num_sampled_rois=512,
-               foreground_fraction=0.25,
-               foreground_iou_threshold=0.5,
-               background_iou_high_threshold=0.5,
-               background_iou_low_threshold=0,
+               mix_gt_boxes: bool = True,
+               num_sampled_rois: int = 512,
+               foreground_fraction: float = 0.25,
+               foreground_iou_threshold: float = 0.5,
+               background_iou_high_threshold: float = 0.5,
+               background_iou_low_threshold: float = 0,
                **kwargs):
     """Initializes a ROI sampler.

@@ -73,7 +72,7 @@ class ROISampler(tf.keras.layers.Layer):
         num_sampled_rois, foreground_fraction)
     super(ROISampler, self).__init__(**kwargs)

-  def call(self, boxes, gt_boxes, gt_classes):
+  def call(self, boxes: tf.Tensor, gt_boxes: tf.Tensor, gt_classes: tf.Tensor):
     """Assigns the proposals with groundtruth classes and performs subsmpling.

     Given `proposed_boxes`, `gt_boxes`, and `gt_classes`, the function uses the
official/vision/beta/tasks/image_classification.py (view file @ c2e19c97)

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Lint as: python3
 """Image classification task definition."""
 from typing import Any, Optional, List, Tuple
 from absl import logging
 import tensorflow as tf

@@ -51,7 +51,7 @@ class ImageClassificationTask(base_task.Task):
     return model

   def initialize(self, model: tf.keras.Model):
-    """Loading pretrained checkpoint."""
+    """Loads pretrained checkpoint."""
     if not self.task_config.init_checkpoint:
       return

@@ -75,7 +75,9 @@ class ImageClassificationTask(base_task.Task):
     logging.info('Finished loading pretrained checkpoint from %s', ckpt_dir_or_file)

-  def build_inputs(self, params, input_context=None):
+  def build_inputs(self,
+                   params: exp_cfg.DataConfig,
+                   input_context: Optional[tf.distribute.InputContext] = None):
     """Builds classification input."""
     num_classes = self.task_config.model.num_classes

@@ -112,13 +114,16 @@ class ImageClassificationTask(base_task.Task):
     return dataset

-  def build_losses(self, labels, model_outputs, aux_losses=None):
-    """Sparse categorical cross entropy loss.
+  def build_losses(self,
+                   labels: tf.Tensor,
+                   model_outputs: tf.Tensor,
+                   aux_losses: Optional[Any] = None):
+    """Builds sparse categorical cross entropy loss.

     Args:
-      labels: labels.
+      labels: Input groundtruth labels.
       model_outputs: Output logits of the classifier.
-      aux_losses: auxiliarly loss tensors, i.e. `losses` in keras.Model.
+      aux_losses: The auxiliarly loss tensors, i.e. `losses` in tf.keras.Model.

     Returns:
       The total loss tensor.

@@ -140,7 +145,7 @@ class ImageClassificationTask(base_task.Task):
     return total_loss

-  def build_metrics(self, training=True):
+  def build_metrics(self, training: bool = True):
     """Gets streaming metrics for training/validation."""
     k = self.task_config.evaluation.top_k
     if self.task_config.losses.one_hot:

@@ -155,14 +160,18 @@ class ImageClassificationTask(base_task.Task):
                  k=k, name='top_{}_accuracy'.format(k))]
     return metrics

-  def train_step(self, inputs, model, optimizer, metrics=None):
+  def train_step(self,
+                 inputs: Tuple[Any, Any],
+                 model: tf.keras.Model,
+                 optimizer: tf.keras.optimizers.Optimizer,
+                 metrics: Optional[List[Any]] = None):
     """Does forward and backward.

     Args:
-      inputs: a dictionary of input tensors.
-      model: the model, forward pass definition.
-      optimizer: the optimizer for this training step.
-      metrics: a nested structure of metrics objects.
+      inputs: A tuple of of input tensors of (features, labels).
+      model: A tf.keras.Model instance.
+      optimizer: The optimizer for this training step.
+      metrics: A nested structure of metrics objects.

     Returns:
       A dictionary of logs.

@@ -209,13 +218,16 @@ class ImageClassificationTask(base_task.Task):
     logs.update({m.name: m.result() for m in model.metrics})
     return logs

-  def validation_step(self, inputs, model, metrics=None):
-    """Validatation step.
+  def validation_step(self,
+                      inputs: Tuple[Any, Any],
+                      model: tf.keras.Model,
+                      metrics: Optional[List[Any]] = None):
+    """Runs validatation step.

     Args:
-      inputs: a dictionary of input tensors.
-      model: the keras.Model.
-      metrics: a nested structure of metrics objects.
+      inputs: A tuple of of input tensors of (features, labels).
+      model: A tf.keras.Model instance.
+      metrics: A nested structure of metrics objects.

     Returns:
       A dictionary of logs.

@@ -237,6 +249,6 @@ class ImageClassificationTask(base_task.Task):
     logs.update({m.name: m.result() for m in model.metrics})
     return logs

-  def inference_step(self, inputs, model):
+  def inference_step(self, inputs: tf.Tensor, model: tf.keras.Model):
     """Performs the forward step."""
     return model(inputs, training=False)
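The task methods here (and in the remaining task files below) follow the same pattern: the step methods gain annotations from typing plus tf.keras types while the bodies are unchanged. A minimal standalone sketch of what the annotated train_step shape looks like; this illustrates the signature style only and is not the Model Garden implementation:

from typing import Any, List, Optional, Tuple
import tensorflow as tf


def train_step(inputs: Tuple[Any, Any],
               model: tf.keras.Model,
               optimizer: tf.keras.optimizers.Optimizer,
               metrics: Optional[List[Any]] = None) -> dict:
  """Sketch: unpacks (features, labels) and returns a dictionary of logs."""
  features, labels = inputs
  with tf.GradientTape() as tape:
    outputs = model(features, training=True)
    loss = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(
            labels, outputs, from_logits=True))
  # Backward pass and metric updates, mirroring the structure of the task classes.
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  logs = {'loss': loss}
  for m in metrics or []:
    m.update_state(labels, outputs)
    logs[m.name] = m.result()
  return logs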
official/vision/beta/tasks/maskrcnn.py (view file @ c2e19c97)

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Lint as: python3
 """RetinaNet task definition."""
 from typing import Any, Optional, List, Tuple, Mapping
 from absl import logging
 import tensorflow as tf

@@ -30,7 +30,8 @@ from official.vision.beta.losses import maskrcnn_losses
 from official.vision.beta.modeling import factory

-def zero_out_disallowed_class_ids(batch_class_ids, allowed_class_ids):
+def zero_out_disallowed_class_ids(batch_class_ids: tf.Tensor,
+                                  allowed_class_ids: List[int]):
   """Zero out IDs of classes not in allowed_class_ids.

   Args:

@@ -106,7 +107,9 @@ class MaskRCNNTask(base_task.Task):
     logging.info('Finished loading pretrained checkpoint from %s', ckpt_dir_or_file)

-  def build_inputs(self, params, input_context=None):
+  def build_inputs(self,
+                   params: exp_cfg.DataConfig,
+                   input_context: Optional[tf.distribute.InputContext] = None):
     """Build input dataset."""
     decoder_cfg = params.decoder.get()
     if params.decoder.type == 'simple_decoder':

@@ -152,7 +155,10 @@ class MaskRCNNTask(base_task.Task):
     return dataset

-  def build_losses(self, outputs, labels, aux_losses=None):
+  def build_losses(self,
+                   outputs: Mapping[str, Any],
+                   labels: Mapping[str, Any],
+                   aux_losses: Optional[Any] = None):
     """Build Mask R-CNN losses."""
     params = self.task_config

@@ -218,7 +224,7 @@ class MaskRCNNTask(base_task.Task):
     }
     return losses

-  def build_metrics(self, training=True):
+  def build_metrics(self, training: bool = True):
     """Build detection metrics."""
     metrics = []
     if training:

@@ -242,7 +248,11 @@ class MaskRCNNTask(base_task.Task):
     return metrics

-  def train_step(self, inputs, model, optimizer, metrics=None):
+  def train_step(self,
+                 inputs: Tuple[Any, Any],
+                 model: tf.keras.Model,
+                 optimizer: tf.keras.optimizers.Optimizer,
+                 metrics: Optional[List[Any]] = None):
     """Does forward and backward.

     Args:

@@ -294,7 +304,10 @@ class MaskRCNNTask(base_task.Task):
     return logs

-  def validation_step(self, inputs, model, metrics=None):
+  def validation_step(self,
+                      inputs: Tuple[Any, Any],
+                      model: tf.keras.Model,
+                      metrics: Optional[List[Any]] = None):
     """Validatation step.

     Args:
official/vision/beta/tasks/retinanet.py (view file @ c2e19c97)

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Lint as: python3
 """RetinaNet task definition."""
 from typing import Any, Optional, List, Tuple, Mapping
 from absl import logging
 import tensorflow as tf

@@ -84,7 +84,9 @@ class RetinaNetTask(base_task.Task):
     logging.info('Finished loading pretrained checkpoint from %s', ckpt_dir_or_file)

-  def build_inputs(self, params, input_context=None):
+  def build_inputs(self,
+                   params: exp_cfg.DataConfig,
+                   input_context: Optional[tf.distribute.InputContext] = None):
     """Build input dataset."""
     if params.tfds_name:

@@ -131,7 +133,10 @@ class RetinaNetTask(base_task.Task):
     return dataset

-  def build_losses(self, outputs, labels, aux_losses=None):
+  def build_losses(self,
+                   outputs: Mapping[str, Any],
+                   labels: Mapping[str, Any],
+                   aux_losses: Optional[Any] = None):
     """Build RetinaNet losses."""
     params = self.task_config
     cls_loss_fn = keras_cv.losses.FocalLoss(

@@ -172,7 +177,7 @@ class RetinaNetTask(base_task.Task):
     return total_loss, cls_loss, box_loss, model_loss

-  def build_metrics(self, training=True):
+  def build_metrics(self, training: bool = True):
     """Build detection metrics."""
     metrics = []
     metric_names = ['total_loss', 'cls_loss', 'box_loss', 'model_loss']

@@ -190,7 +195,11 @@ class RetinaNetTask(base_task.Task):
     return metrics

-  def train_step(self, inputs, model, optimizer, metrics=None):
+  def train_step(self,
+                 inputs: Tuple[Any, Any],
+                 model: tf.keras.Model,
+                 optimizer: tf.keras.optimizers.Optimizer,
+                 metrics: Optional[List[Any]] = None):
     """Does forward and backward.

     Args:

@@ -241,7 +250,10 @@ class RetinaNetTask(base_task.Task):
     return logs

-  def validation_step(self, inputs, model, metrics=None):
+  def validation_step(self,
+                      inputs: Tuple[Any, Any],
+                      model: tf.keras.Model,
+                      metrics: Optional[List[Any]] = None):
     """Validatation step.

     Args:
official/vision/beta/tasks/semantic_segmentation.py (view file @ c2e19c97)

@@ -12,8 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Lint as: python3
 """Image segmentation task definition."""
 from typing import Any, Optional, List, Tuple, Mapping, Union
 from absl import logging
 import tensorflow as tf

@@ -79,7 +79,9 @@ class SemanticSegmentationTask(base_task.Task):
     logging.info('Finished loading pretrained checkpoint from %s', ckpt_dir_or_file)

-  def build_inputs(self, params, input_context=None):
+  def build_inputs(self,
+                   params: exp_cfg.DataConfig,
+                   input_context: Optional[tf.distribute.InputContext] = None):
     """Builds classification input."""
     ignore_label = self.task_config.losses.ignore_label

@@ -114,7 +116,10 @@ class SemanticSegmentationTask(base_task.Task):
     return dataset

-  def build_losses(self, labels, model_outputs, aux_losses=None):
+  def build_losses(self,
+                   labels: Mapping[str, tf.Tensor],
+                   model_outputs: Union[Mapping[str, tf.Tensor], tf.Tensor],
+                   aux_losses: Optional[Any] = None):
     """Segmentation loss.

     Args:

@@ -140,7 +145,7 @@ class SemanticSegmentationTask(base_task.Task):
     return total_loss

-  def build_metrics(self, training=True):
+  def build_metrics(self, training: bool = True):
     """Gets streaming metrics for training/validation."""
     metrics = []
     if training and self.task_config.evaluation.report_train_mean_iou:

@@ -159,7 +164,11 @@ class SemanticSegmentationTask(base_task.Task):
     return metrics

-  def train_step(self, inputs, model, optimizer, metrics=None):
+  def train_step(self,
+                 inputs: Tuple[Any, Any],
+                 model: tf.keras.Model,
+                 optimizer: tf.keras.optimizers.Optimizer,
+                 metrics: Optional[List[Any]] = None):
     """Does forward and backward.

     Args:

@@ -214,7 +223,10 @@ class SemanticSegmentationTask(base_task.Task):
     return logs

-  def validation_step(self, inputs, model, metrics=None):
+  def validation_step(self,
+                      inputs: Tuple[Any, Any],
+                      model: tf.keras.Model,
+                      metrics: Optional[List[Any]] = None):
     """Validatation step.

     Args:

@@ -251,7 +263,7 @@ class SemanticSegmentationTask(base_task.Task):
     return logs

-  def inference_step(self, inputs, model):
+  def inference_step(self, inputs: tf.Tensor, model: tf.keras.Model):
     """Performs the forward step."""
     return model(inputs, training=False)
official/vision/beta/tasks/video_classification.py (view file @ c2e19c97)

@@ -12,8 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # Lint as: python3
 """Video classification task definition."""
 from typing import Any, Optional, List, Tuple
 from absl import logging
 import tensorflow as tf
 from official.core import base_task

@@ -68,7 +69,9 @@ class VideoClassificationTask(base_task.Task):
             tf.io.VarLenFeature(dtype=tf.float32))
     return decoder.decode

-  def build_inputs(self, params: exp_cfg.DataConfig, input_context=None):
+  def build_inputs(self,
+                   params: exp_cfg.DataConfig,
+                   input_context: Optional[tf.distribute.InputContext] = None):
     """Builds classification input."""
     parser = video_input.Parser(input_params=params)

@@ -85,7 +88,10 @@ class VideoClassificationTask(base_task.Task):
     return dataset

-  def build_losses(self, labels, model_outputs, aux_losses=None):
+  def build_losses(self,
+                   labels: Any,
+                   model_outputs: Any,
+                   aux_losses: Optional[Any] = None):
     """Sparse categorical cross entropy loss.

     Args:

@@ -132,7 +138,7 @@ class VideoClassificationTask(base_task.Task):
     return all_losses

-  def build_metrics(self, training=True):
+  def build_metrics(self, training: bool = True):
     """Gets streaming metrics for training/validation."""
     if self.task_config.losses.one_hot:
       metrics = [

@@ -168,7 +174,8 @@ class VideoClassificationTask(base_task.Task):
     ]
     return metrics

-  def process_metrics(self, metrics, labels, model_outputs):
+  def process_metrics(self, metrics: List[Any], labels: Any,
+                      model_outputs: Any):
     """Process and update metrics.

     Called when using custom training loop API.

@@ -183,7 +190,11 @@ class VideoClassificationTask(base_task.Task):
     for metric in metrics:
       metric.update_state(labels, model_outputs)

-  def train_step(self, inputs, model, optimizer, metrics=None):
+  def train_step(self,
+                 inputs: Tuple[Any, Any],
+                 model: tf.keras.Model,
+                 optimizer: tf.keras.optimizers.Optimizer,
+                 metrics: Optional[List[Any]] = None):
     """Does forward and backward.

     Args:

@@ -240,7 +251,10 @@ class VideoClassificationTask(base_task.Task):
     logs.update({m.name: m.result() for m in model.metrics})
     return logs

-  def validation_step(self, inputs, model, metrics=None):
+  def validation_step(self,
+                      inputs: Tuple[Any, Any],
+                      model: tf.keras.Model,
+                      metrics: Optional[List[Any]] = None):
     """Validatation step.

     Args:

@@ -266,7 +280,7 @@ class VideoClassificationTask(base_task.Task):
     logs.update({m.name: m.result() for m in model.metrics})
     return logs

-  def inference_step(self, features, model):
+  def inference_step(self, features: tf.Tensor, model: tf.keras.Model):
     """Performs the forward step."""
     outputs = model(features, training=False)
     if self.task_config.train_data.is_multilabel: