Commit 88253ce5 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 326286926
parent 52371ffe
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Utils used to manipulate tensor shapes.""" """Utils used to manipulate tensor shapes."""
import tensorflow as tf import tensorflow as tf
...@@ -42,7 +41,8 @@ def assert_shape_equal(shape_a, shape_b): ...@@ -42,7 +41,8 @@ def assert_shape_equal(shape_a, shape_b):
all(isinstance(dim, int) for dim in shape_b)): all(isinstance(dim, int) for dim in shape_b)):
if shape_a != shape_b: if shape_a != shape_b:
raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b)) raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
else: return tf.no_op() else:
return tf.no_op()
else: else:
return tf.assert_equal(shape_a, shape_b) return tf.assert_equal(shape_a, shape_b)
...@@ -87,9 +87,7 @@ def pad_or_clip_nd(tensor, output_shape): ...@@ -87,9 +87,7 @@ def pad_or_clip_nd(tensor, output_shape):
if shape is not None else -1 for i, shape in enumerate(output_shape) if shape is not None else -1 for i, shape in enumerate(output_shape)
] ]
clipped_tensor = tf.slice( clipped_tensor = tf.slice(
tensor, tensor, begin=tf.zeros(len(clip_size), dtype=tf.int32), size=clip_size)
begin=tf.zeros(len(clip_size), dtype=tf.int32),
size=clip_size)
# Pad tensor if the shape of clipped tensor is smaller than the expected # Pad tensor if the shape of clipped tensor is smaller than the expected
# shape. # shape.
...@@ -99,10 +97,7 @@ def pad_or_clip_nd(tensor, output_shape): ...@@ -99,10 +97,7 @@ def pad_or_clip_nd(tensor, output_shape):
for i, shape in enumerate(output_shape) for i, shape in enumerate(output_shape)
] ]
paddings = tf.stack( paddings = tf.stack(
[ [tf.zeros(len(trailing_paddings), dtype=tf.int32), trailing_paddings],
tf.zeros(len(trailing_paddings), dtype=tf.int32),
trailing_paddings
],
axis=1) axis=1)
padded_tensor = tf.pad(tensor=clipped_tensor, paddings=paddings) padded_tensor = tf.pad(tensor=clipped_tensor, paddings=paddings)
output_static_shape = [ output_static_shape = [
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Base target assigner module. """Base target assigner module.
The job of a TargetAssigner is, for a given set of anchors (bounding boxes) and The job of a TargetAssigner is, for a given set of anchors (bounding boxes) and
...@@ -31,35 +30,39 @@ Note that TargetAssigners only operate on detections from a single ...@@ -31,35 +30,39 @@ Note that TargetAssigners only operate on detections from a single
image at a time, so any logic for applying a TargetAssigner to multiple image at a time, so any logic for applying a TargetAssigner to multiple
images must be handled externally. images must be handled externally.
""" """
import tensorflow as tf import tensorflow as tf
from official.vision.detection.utils.object_detection import box_list from official.vision.detection.utils.object_detection import box_list
from official.vision.detection.utils.object_detection import shape_utils from official.vision.detection.utils.object_detection import shape_utils
KEYPOINTS_FIELD_NAME = 'keypoints' KEYPOINTS_FIELD_NAME = 'keypoints'
class TargetAssigner(object): class TargetAssigner(object):
"""Target assigner to compute classification and regression targets.""" """Target assigner to compute classification and regression targets."""
def __init__(self, similarity_calc, matcher, box_coder, def __init__(self,
negative_class_weight=1.0, unmatched_cls_target=None): similarity_calc,
matcher,
box_coder,
negative_class_weight=1.0,
unmatched_cls_target=None):
"""Construct Object Detection Target Assigner. """Construct Object Detection Target Assigner.
Args: Args:
similarity_calc: a RegionSimilarityCalculator similarity_calc: a RegionSimilarityCalculator
matcher: Matcher used to match groundtruth to anchors. matcher: Matcher used to match groundtruth to anchors.
box_coder: BoxCoder used to encode matching groundtruth boxes with box_coder: BoxCoder used to encode matching groundtruth boxes with respect
respect to anchors. to anchors.
negative_class_weight: classification weight to be associated to negative negative_class_weight: classification weight to be associated to negative
anchors (default: 1.0). The weight must be in [0., 1.]. anchors (default: 1.0). The weight must be in [0., 1.].
unmatched_cls_target: a float32 tensor with shape [d_1, d_2, ..., d_k] unmatched_cls_target: a float32 tensor with shape [d_1, d_2, ..., d_k]
which is consistent with the classification target for each which is consistent with the classification target for each anchor (and
anchor (and can be empty for scalar targets). This shape must thus be can be empty for scalar targets). This shape must thus be compatible
compatible with the groundtruth labels that are passed to the "assign" with the groundtruth labels that are passed to the "assign" function
function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]). (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]). If set to None,
If set to None, unmatched_cls_target is set to be [0] for each anchor. unmatched_cls_target is set to be [0] for each anchor.
Raises: Raises:
ValueError: if similarity_calc is not a RegionSimilarityCalculator or ValueError: if similarity_calc is not a RegionSimilarityCalculator or
...@@ -78,8 +81,12 @@ class TargetAssigner(object): ...@@ -78,8 +81,12 @@ class TargetAssigner(object):
def box_coder(self): def box_coder(self):
return self._box_coder return self._box_coder
def assign(self, anchors, groundtruth_boxes, groundtruth_labels=None, def assign(self,
groundtruth_weights=None, **params): anchors,
groundtruth_boxes,
groundtruth_labels=None,
groundtruth_weights=None,
**params):
"""Assign classification and regression targets to each anchor. """Assign classification and regression targets to each anchor.
For a given set of anchors and groundtruth detections, match anchors For a given set of anchors and groundtruth detections, match anchors
...@@ -93,16 +100,16 @@ class TargetAssigner(object): ...@@ -93,16 +100,16 @@ class TargetAssigner(object):
Args: Args:
anchors: a BoxList representing N anchors anchors: a BoxList representing N anchors
groundtruth_boxes: a BoxList representing M groundtruth boxes groundtruth_boxes: a BoxList representing M groundtruth boxes
groundtruth_labels: a tensor of shape [M, d_1, ... d_k] groundtruth_labels: a tensor of shape [M, d_1, ... d_k] with labels for
with labels for each of the ground_truth boxes. The subshape each of the ground_truth boxes. The subshape [d_1, ... d_k] can be empty
[d_1, ... d_k] can be empty (corresponding to scalar inputs). When set (corresponding to scalar inputs). When set to None, groundtruth_labels
to None, groundtruth_labels assumes a binary problem where all assumes a binary problem where all ground_truth boxes get a positive
ground_truth boxes get a positive label (of 1). label (of 1).
groundtruth_weights: a float tensor of shape [M] indicating the weight to groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors match to a particular groundtruth box. The weights assign to all anchors match to a particular groundtruth box. The weights
must be in [0., 1.]. If None, all weights are set to 1. must be in [0., 1.]. If None, all weights are set to 1.
**params: Additional keyword arguments for specific implementations of **params: Additional keyword arguments for specific implementations of the
the Matcher. Matcher.
Returns: Returns:
cls_targets: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], cls_targets: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k],
...@@ -125,16 +132,15 @@ class TargetAssigner(object): ...@@ -125,16 +132,15 @@ class TargetAssigner(object):
raise ValueError('groundtruth_boxes must be an BoxList') raise ValueError('groundtruth_boxes must be an BoxList')
if groundtruth_labels is None: if groundtruth_labels is None:
groundtruth_labels = tf.ones(tf.expand_dims(groundtruth_boxes.num_boxes(), groundtruth_labels = tf.ones(
0)) tf.expand_dims(groundtruth_boxes.num_boxes(), 0))
groundtruth_labels = tf.expand_dims(groundtruth_labels, -1) groundtruth_labels = tf.expand_dims(groundtruth_labels, -1)
unmatched_shape_assert = shape_utils.assert_shape_equal( unmatched_shape_assert = shape_utils.assert_shape_equal(
shape_utils.combined_static_and_dynamic_shape(groundtruth_labels)[1:], shape_utils.combined_static_and_dynamic_shape(groundtruth_labels)[1:],
shape_utils.combined_static_and_dynamic_shape( shape_utils.combined_static_and_dynamic_shape(
self._unmatched_cls_target)) self._unmatched_cls_target))
labels_and_box_shapes_assert = shape_utils.assert_shape_equal( labels_and_box_shapes_assert = shape_utils.assert_shape_equal(
shape_utils.combined_static_and_dynamic_shape( shape_utils.combined_static_and_dynamic_shape(groundtruth_labels)[:1],
groundtruth_labels)[:1],
shape_utils.combined_static_and_dynamic_shape( shape_utils.combined_static_and_dynamic_shape(
groundtruth_boxes.get())[:1]) groundtruth_boxes.get())[:1])
...@@ -145,11 +151,10 @@ class TargetAssigner(object): ...@@ -145,11 +151,10 @@ class TargetAssigner(object):
groundtruth_weights = tf.ones([num_gt_boxes], dtype=tf.float32) groundtruth_weights = tf.ones([num_gt_boxes], dtype=tf.float32)
with tf.control_dependencies( with tf.control_dependencies(
[unmatched_shape_assert, labels_and_box_shapes_assert]): [unmatched_shape_assert, labels_and_box_shapes_assert]):
match_quality_matrix = self._similarity_calc.compare(groundtruth_boxes, match_quality_matrix = self._similarity_calc.compare(
anchors) groundtruth_boxes, anchors)
match = self._matcher.match(match_quality_matrix, **params) match = self._matcher.match(match_quality_matrix, **params)
reg_targets = self._create_regression_targets(anchors, reg_targets = self._create_regression_targets(anchors, groundtruth_boxes,
groundtruth_boxes,
match) match)
cls_targets = self._create_classification_targets(groundtruth_labels, cls_targets = self._create_classification_targets(groundtruth_labels,
match) match)
...@@ -210,8 +215,8 @@ class TargetAssigner(object): ...@@ -210,8 +215,8 @@ class TargetAssigner(object):
match.match_results) match.match_results)
# Zero out the unmatched and ignored regression targets. # Zero out the unmatched and ignored regression targets.
unmatched_ignored_reg_targets = tf.tile( unmatched_ignored_reg_targets = tf.tile(self._default_regression_target(),
self._default_regression_target(), [match_results_shape[0], 1]) [match_results_shape[0], 1])
matched_anchors_mask = match.matched_column_indicator() matched_anchors_mask = match.matched_column_indicator()
# To broadcast matched_anchors_mask to the same shape as # To broadcast matched_anchors_mask to the same shape as
# matched_reg_targets. # matched_reg_targets.
...@@ -233,7 +238,7 @@ class TargetAssigner(object): ...@@ -233,7 +238,7 @@ class TargetAssigner(object):
Returns: Returns:
default_target: a float32 tensor with shape [1, box_code_dimension] default_target: a float32 tensor with shape [1, box_code_dimension]
""" """
return tf.constant([self._box_coder.code_size*[0]], tf.float32) return tf.constant([self._box_coder.code_size * [0]], tf.float32)
def _create_classification_targets(self, groundtruth_labels, match): def _create_classification_targets(self, groundtruth_labels, match):
"""Create classification targets for each anchor. """Create classification targets for each anchor.
...@@ -243,11 +248,11 @@ class TargetAssigner(object): ...@@ -243,11 +248,11 @@ class TargetAssigner(object):
to anything are given the target self._unmatched_cls_target to anything are given the target self._unmatched_cls_target
Args: Args:
groundtruth_labels: a tensor of shape [num_gt_boxes, d_1, ... d_k] groundtruth_labels: a tensor of shape [num_gt_boxes, d_1, ... d_k] with
with labels for each of the ground_truth boxes. The subshape labels for each of the ground_truth boxes. The subshape [d_1, ... d_k]
[d_1, ... d_k] can be empty (corresponding to scalar labels). can be empty (corresponding to scalar labels).
match: a matcher.Match object that provides a matching between anchors match: a matcher.Match object that provides a matching between anchors and
and groundtruth boxes. groundtruth boxes.
Returns: Returns:
a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], where the a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], where the
...@@ -267,8 +272,8 @@ class TargetAssigner(object): ...@@ -267,8 +272,8 @@ class TargetAssigner(object):
negative anchor. negative anchor.
Args: Args:
match: a matcher.Match object that provides a matching between anchors match: a matcher.Match object that provides a matching between anchors and
and groundtruth boxes. groundtruth boxes.
groundtruth_weights: a float tensor of shape [M] indicating the weight to groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors match to a particular groundtruth box. assign to all anchors match to a particular groundtruth box.
...@@ -278,9 +283,7 @@ class TargetAssigner(object): ...@@ -278,9 +283,7 @@ class TargetAssigner(object):
return match.gather_based_on_match( return match.gather_based_on_match(
groundtruth_weights, ignored_value=0., unmatched_value=0.) groundtruth_weights, ignored_value=0., unmatched_value=0.)
def _create_classification_weights(self, def _create_classification_weights(self, match, groundtruth_weights):
match,
groundtruth_weights):
"""Create classification weights for each anchor. """Create classification weights for each anchor.
Positive (matched) anchors are associated with a weight of Positive (matched) anchors are associated with a weight of
...@@ -291,8 +294,8 @@ class TargetAssigner(object): ...@@ -291,8 +294,8 @@ class TargetAssigner(object):
the case in object detection). the case in object detection).
Args: Args:
match: a matcher.Match object that provides a matching between anchors match: a matcher.Match object that provides a matching between anchors and
and groundtruth boxes. groundtruth boxes.
groundtruth_weights: a float tensor of shape [M] indicating the weight to groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors match to a particular groundtruth box. assign to all anchors match to a particular groundtruth box.
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""A set of functions that are used for visualization. """A set of functions that are used for visualization.
These functions often receive an image, perform some visualization on the image. These functions often receive an image, perform some visualization on the image.
...@@ -21,9 +20,11 @@ The functions do not return a value, instead they modify the image itself. ...@@ -21,9 +20,11 @@ The functions do not return a value, instead they modify the image itself.
""" """
import collections import collections
import functools import functools
from absl import logging from absl import logging
# Set headless-friendly backend. # Set headless-friendly backend.
import matplotlib; matplotlib.use('Agg') # pylint: disable=multiple-statements import matplotlib
matplotlib.use('Agg') # pylint: disable=multiple-statements
import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top
import numpy as np import numpy as np
import PIL.Image as Image import PIL.Image as Image
...@@ -36,7 +37,6 @@ import tensorflow as tf ...@@ -36,7 +37,6 @@ import tensorflow as tf
from official.vision.detection.utils import box_utils from official.vision.detection.utils import box_utils
from official.vision.detection.utils.object_detection import shape_utils from official.vision.detection.utils.object_detection import shape_utils
_TITLE_LEFT_MARGIN = 10 _TITLE_LEFT_MARGIN = 10
_TITLE_TOP_MARGIN = 10 _TITLE_TOP_MARGIN = 10
STANDARD_COLORS = [ STANDARD_COLORS = [
...@@ -99,9 +99,9 @@ def visualize_images_with_bounding_boxes(images, box_outputs, step, ...@@ -99,9 +99,9 @@ def visualize_images_with_bounding_boxes(images, box_outputs, step,
summary_writer): summary_writer):
"""Records subset of evaluation images with bounding boxes.""" """Records subset of evaluation images with bounding boxes."""
if not isinstance(images, list): if not isinstance(images, list):
logging.warning('visualize_images_with_bounding_boxes expects list of ' logging.warning(
'images but received type: %s and value: %s', 'visualize_images_with_bounding_boxes expects list of '
type(images), images) 'images but received type: %s and value: %s', type(images), images)
return return
image_shape = tf.shape(images[0]) image_shape = tf.shape(images[0])
...@@ -140,11 +140,11 @@ def draw_bounding_box_on_image_array(image, ...@@ -140,11 +140,11 @@ def draw_bounding_box_on_image_array(image,
xmax: xmax of bounding box. xmax: xmax of bounding box.
color: color to draw bounding box. Default is red. color: color to draw bounding box. Default is red.
thickness: line thickness. Default value is 4. thickness: line thickness. Default value is 4.
display_str_list: list of strings to display in box display_str_list: list of strings to display in box (each to be shown on its
(each to be shown on its own line). own line).
use_normalized_coordinates: If True (default), treat coordinates use_normalized_coordinates: If True (default), treat coordinates ymin, xmin,
ymin, xmin, ymax, xmax as relative to the image. Otherwise treat ymax, xmax as relative to the image. Otherwise treat coordinates as
coordinates as absolute. absolute.
""" """
image_pil = Image.fromarray(np.uint8(image)).convert('RGB') image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color, draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color,
...@@ -180,11 +180,11 @@ def draw_bounding_box_on_image(image, ...@@ -180,11 +180,11 @@ def draw_bounding_box_on_image(image,
xmax: xmax of bounding box. xmax: xmax of bounding box.
color: color to draw bounding box. Default is red. color: color to draw bounding box. Default is red.
thickness: line thickness. Default value is 4. thickness: line thickness. Default value is 4.
display_str_list: list of strings to display in box display_str_list: list of strings to display in box (each to be shown on its
(each to be shown on its own line). own line).
use_normalized_coordinates: If True (default), treat coordinates use_normalized_coordinates: If True (default), treat coordinates ymin, xmin,
ymin, xmin, ymax, xmax as relative to the image. Otherwise treat ymax, xmax as relative to the image. Otherwise treat coordinates as
coordinates as absolute. absolute.
""" """
draw = ImageDraw.Draw(image) draw = ImageDraw.Draw(image)
im_width, im_height = image.size im_width, im_height = image.size
...@@ -193,8 +193,10 @@ def draw_bounding_box_on_image(image, ...@@ -193,8 +193,10 @@ def draw_bounding_box_on_image(image,
ymin * im_height, ymax * im_height) ymin * im_height, ymax * im_height)
else: else:
(left, right, top, bottom) = (xmin, xmax, ymin, ymax) (left, right, top, bottom) = (xmin, xmax, ymin, ymax)
draw.line([(left, top), (left, bottom), (right, bottom), draw.line([(left, top), (left, bottom), (right, bottom), (right, top),
(right, top), (left, top)], width=thickness, fill=color) (left, top)],
width=thickness,
fill=color)
try: try:
font = ImageFont.truetype('arial.ttf', 24) font = ImageFont.truetype('arial.ttf', 24)
except IOError: except IOError:
...@@ -215,15 +217,13 @@ def draw_bounding_box_on_image(image, ...@@ -215,15 +217,13 @@ def draw_bounding_box_on_image(image,
for display_str in display_str_list[::-1]: for display_str in display_str_list[::-1]:
text_width, text_height = font.getsize(display_str) text_width, text_height = font.getsize(display_str)
margin = np.ceil(0.05 * text_height) margin = np.ceil(0.05 * text_height)
draw.rectangle( draw.rectangle([(left, text_bottom - text_height - 2 * margin),
[(left, text_bottom - text_height - 2 * margin), (left + text_width, (left + text_width, text_bottom)],
text_bottom)], fill=color)
fill=color) draw.text((left + margin, text_bottom - text_height - margin),
draw.text( display_str,
(left + margin, text_bottom - text_height - margin), fill='black',
display_str, font=font)
fill='black',
font=font)
text_bottom -= text_height - 2 * margin text_bottom -= text_height - 2 * margin
...@@ -236,15 +236,13 @@ def draw_bounding_boxes_on_image_array(image, ...@@ -236,15 +236,13 @@ def draw_bounding_boxes_on_image_array(image,
Args: Args:
image: a numpy array object. image: a numpy array object.
boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). The
The coordinates are in normalized format between [0, 1]. coordinates are in normalized format between [0, 1].
color: color to draw bounding box. Default is red. color: color to draw bounding box. Default is red.
thickness: line thickness. Default value is 4. thickness: line thickness. Default value is 4.
display_str_list_list: list of list of strings. display_str_list_list: list of list of strings. a list of strings for each
a list of strings for each bounding box. bounding box. The reason to pass a list of strings for a bounding box is
The reason to pass a list of strings for a that it might contain multiple labels.
bounding box is that it might contain
multiple labels.
Raises: Raises:
ValueError: if boxes is not a [N, 4] array ValueError: if boxes is not a [N, 4] array
...@@ -264,15 +262,13 @@ def draw_bounding_boxes_on_image(image, ...@@ -264,15 +262,13 @@ def draw_bounding_boxes_on_image(image,
Args: Args:
image: a PIL.Image object. image: a PIL.Image object.
boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). The
The coordinates are in normalized format between [0, 1]. coordinates are in normalized format between [0, 1].
color: color to draw bounding box. Default is red. color: color to draw bounding box. Default is red.
thickness: line thickness. Default value is 4. thickness: line thickness. Default value is 4.
display_str_list_list: list of list of strings. display_str_list_list: list of list of strings. a list of strings for each
a list of strings for each bounding box. bounding box. The reason to pass a list of strings for a bounding box is
The reason to pass a list of strings for a that it might contain multiple labels.
bounding box is that it might contain
multiple labels.
Raises: Raises:
ValueError: if boxes is not a [N, 4] array ValueError: if boxes is not a [N, 4] array
...@@ -319,8 +315,9 @@ def _visualize_boxes_and_keypoints(image, boxes, classes, scores, keypoints, ...@@ -319,8 +315,9 @@ def _visualize_boxes_and_keypoints(image, boxes, classes, scores, keypoints,
**kwargs) **kwargs)
def _visualize_boxes_and_masks_and_keypoints( def _visualize_boxes_and_masks_and_keypoints(image, boxes, classes, scores,
image, boxes, classes, scores, masks, keypoints, category_index, **kwargs): masks, keypoints, category_index,
**kwargs):
return visualize_boxes_and_labels_on_image_array( return visualize_boxes_and_labels_on_image_array(
image, image,
boxes, boxes,
...@@ -374,8 +371,8 @@ def draw_bounding_boxes_on_image_tensors(images, ...@@ -374,8 +371,8 @@ def draw_bounding_boxes_on_image_tensors(images,
max_boxes_to_draw: Maximum number of boxes to draw on an image. Default 20. max_boxes_to_draw: Maximum number of boxes to draw on an image. Default 20.
min_score_thresh: Minimum score threshold for visualization. Default 0.2. min_score_thresh: Minimum score threshold for visualization. Default 0.2.
use_normalized_coordinates: Whether to assume boxes and kepoints are in use_normalized_coordinates: Whether to assume boxes and kepoints are in
normalized coordinates (as opposed to absolute coordiantes). normalized coordinates (as opposed to absolute coordiantes). Default is
Default is True. True.
Returns: Returns:
4D image tensor of type uint8, with boxes drawn on top. 4D image tensor of type uint8, with boxes drawn on top.
...@@ -432,17 +429,15 @@ def draw_bounding_boxes_on_image_tensors(images, ...@@ -432,17 +429,15 @@ def draw_bounding_boxes_on_image_tensors(images,
_visualize_boxes, _visualize_boxes,
category_index=category_index, category_index=category_index,
**visualization_keyword_args) **visualization_keyword_args)
elems = [ elems = [true_shapes, original_shapes, images, boxes, classes, scores]
true_shapes, original_shapes, images, boxes, classes, scores
]
def draw_boxes(image_and_detections): def draw_boxes(image_and_detections):
"""Draws boxes on image.""" """Draws boxes on image."""
true_shape = image_and_detections[0] true_shape = image_and_detections[0]
original_shape = image_and_detections[1] original_shape = image_and_detections[1]
if true_image_shape is not None: if true_image_shape is not None:
image = shape_utils.pad_or_clip_nd( image = shape_utils.pad_or_clip_nd(image_and_detections[2],
image_and_detections[2], [true_shape[0], true_shape[1], 3]) [true_shape[0], true_shape[1], 3])
if original_image_spatial_shape is not None: if original_image_spatial_shape is not None:
image_and_detections[2] = _resize_original_image(image, original_shape) image_and_detections[2] = _resize_original_image(image, original_shape)
...@@ -500,7 +495,8 @@ def draw_keypoints_on_image(image, ...@@ -500,7 +495,8 @@ def draw_keypoints_on_image(image,
for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y): for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y):
draw.ellipse([(keypoint_x - radius, keypoint_y - radius), draw.ellipse([(keypoint_x - radius, keypoint_y - radius),
(keypoint_x + radius, keypoint_y + radius)], (keypoint_x + radius, keypoint_y + radius)],
outline=color, fill=color) outline=color,
fill=color)
def draw_mask_on_image_array(image, mask, color='red', alpha=0.4): def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
...@@ -508,8 +504,8 @@ def draw_mask_on_image_array(image, mask, color='red', alpha=0.4): ...@@ -508,8 +504,8 @@ def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
Args: Args:
image: uint8 numpy array with shape (img_height, img_height, 3) image: uint8 numpy array with shape (img_height, img_height, 3)
mask: a uint8 numpy array of shape (img_height, img_height) with mask: a uint8 numpy array of shape (img_height, img_height) with values
values between either 0 or 1. between either 0 or 1.
color: color to draw the keypoints with. Default is red. color: color to draw the keypoints with. Default is red.
alpha: transparency value between 0 and 1. (default: 0.4) alpha: transparency value between 0 and 1. (default: 0.4)
...@@ -531,7 +527,7 @@ def draw_mask_on_image_array(image, mask, color='red', alpha=0.4): ...@@ -531,7 +527,7 @@ def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
solid_color = np.expand_dims( solid_color = np.expand_dims(
np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3]) np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3])
pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert('RGBA') pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert('RGBA')
pil_mask = Image.fromarray(np.uint8(255.0*alpha*mask)).convert('L') pil_mask = Image.fromarray(np.uint8(255.0 * alpha * mask)).convert('L')
pil_image = Image.composite(pil_solid_color, pil_image, pil_mask) pil_image = Image.composite(pil_solid_color, pil_image, pil_mask)
np.copyto(image, np.array(pil_image.convert('RGB'))) np.copyto(image, np.array(pil_image.convert('RGB')))
...@@ -565,21 +561,20 @@ def visualize_boxes_and_labels_on_image_array( ...@@ -565,21 +561,20 @@ def visualize_boxes_and_labels_on_image_array(
boxes: a numpy array of shape [N, 4] boxes: a numpy array of shape [N, 4]
classes: a numpy array of shape [N]. Note that class indices are 1-based, classes: a numpy array of shape [N]. Note that class indices are 1-based,
and match the keys in the label map. and match the keys in the label map.
scores: a numpy array of shape [N] or None. If scores=None, then scores: a numpy array of shape [N] or None. If scores=None, then this
this function assumes that the boxes to be plotted are groundtruth function assumes that the boxes to be plotted are groundtruth boxes and
boxes and plot all boxes as black with no classes or scores. plot all boxes as black with no classes or scores.
category_index: a dict containing category dictionaries (each holding category_index: a dict containing category dictionaries (each holding
category index `id` and category name `name`) keyed by category indices. category index `id` and category name `name`) keyed by category indices.
instance_masks: a numpy array of shape [N, image_height, image_width] with instance_masks: a numpy array of shape [N, image_height, image_width] with
values ranging between 0 and 1, can be None. values ranging between 0 and 1, can be None.
instance_boundaries: a numpy array of shape [N, image_height, image_width] instance_boundaries: a numpy array of shape [N, image_height, image_width]
with values ranging between 0 and 1, can be None. with values ranging between 0 and 1, can be None.
keypoints: a numpy array of shape [N, num_keypoints, 2], can keypoints: a numpy array of shape [N, num_keypoints, 2], can be None
be None use_normalized_coordinates: whether boxes is to be interpreted as normalized
use_normalized_coordinates: whether boxes is to be interpreted as coordinates or not.
normalized coordinates or not. max_boxes_to_draw: maximum number of boxes to visualize. If None, draw all
max_boxes_to_draw: maximum number of boxes to visualize. If None, draw boxes.
all boxes.
min_score_thresh: minimum score threshold for a box to be visualized min_score_thresh: minimum score threshold for a box to be visualized
agnostic_mode: boolean (default: False) controlling whether to evaluate in agnostic_mode: boolean (default: False) controlling whether to evaluate in
class-agnostic mode or not. This mode will display scores but ignore class-agnostic mode or not. This mode will display scores but ignore
...@@ -624,32 +619,25 @@ def visualize_boxes_and_labels_on_image_array( ...@@ -624,32 +619,25 @@ def visualize_boxes_and_labels_on_image_array(
display_str = str(class_name) display_str = str(class_name)
if not skip_scores: if not skip_scores:
if not display_str: if not display_str:
display_str = '{}%'.format(int(100*scores[i])) display_str = '{}%'.format(int(100 * scores[i]))
else: else:
display_str = '{}: {}%'.format(display_str, int(100*scores[i])) display_str = '{}: {}%'.format(display_str, int(100 * scores[i]))
box_to_display_str_map[box].append(display_str) box_to_display_str_map[box].append(display_str)
if agnostic_mode: if agnostic_mode:
box_to_color_map[box] = 'DarkOrange' box_to_color_map[box] = 'DarkOrange'
else: else:
box_to_color_map[box] = STANDARD_COLORS[ box_to_color_map[box] = STANDARD_COLORS[classes[i] %
classes[i] % len(STANDARD_COLORS)] len(STANDARD_COLORS)]
# Draw all boxes onto image. # Draw all boxes onto image.
for box, color in box_to_color_map.items(): for box, color in box_to_color_map.items():
ymin, xmin, ymax, xmax = box ymin, xmin, ymax, xmax = box
if instance_masks is not None: if instance_masks is not None:
draw_mask_on_image_array( draw_mask_on_image_array(
image, image, box_to_instance_masks_map[box], color=color)
box_to_instance_masks_map[box],
color=color
)
if instance_boundaries is not None: if instance_boundaries is not None:
draw_mask_on_image_array( draw_mask_on_image_array(
image, image, box_to_instance_boundaries_map[box], color='red', alpha=1.0)
box_to_instance_boundaries_map[box],
color='red',
alpha=1.0
)
draw_bounding_box_on_image_array( draw_bounding_box_on_image_array(
image, image,
ymin, ymin,
...@@ -681,13 +669,15 @@ def add_cdf_image_summary(values, name): ...@@ -681,13 +669,15 @@ def add_cdf_image_summary(values, name):
values: a 1-D float32 tensor containing the values. values: a 1-D float32 tensor containing the values.
name: name for the image summary. name: name for the image summary.
""" """
def cdf_plot(values): def cdf_plot(values):
"""Numpy function to plot CDF.""" """Numpy function to plot CDF."""
normalized_values = values / np.sum(values) normalized_values = values / np.sum(values)
sorted_values = np.sort(normalized_values) sorted_values = np.sort(normalized_values)
cumulative_values = np.cumsum(sorted_values) cumulative_values = np.cumsum(sorted_values)
fraction_of_examples = (np.arange(cumulative_values.size, dtype=np.float32) fraction_of_examples = (
/ cumulative_values.size) np.arange(cumulative_values.size, dtype=np.float32) /
cumulative_values.size)
fig = plt.figure(frameon=False) fig = plt.figure(frameon=False)
ax = fig.add_subplot('111') ax = fig.add_subplot('111')
ax.plot(fraction_of_examples, cumulative_values) ax.plot(fraction_of_examples, cumulative_values)
...@@ -695,8 +685,9 @@ def add_cdf_image_summary(values, name): ...@@ -695,8 +685,9 @@ def add_cdf_image_summary(values, name):
ax.set_xlabel('fraction of examples') ax.set_xlabel('fraction of examples')
fig.canvas.draw() fig.canvas.draw()
width, height = fig.get_size_inches() * fig.get_dpi() width, height = fig.get_size_inches() * fig.get_dpi()
image = np.fromstring(fig.canvas.tostring_rgb(), dtype='uint8').reshape( image = np.fromstring(
1, int(height), int(width), 3) fig.canvas.tostring_rgb(),
dtype='uint8').reshape(1, int(height), int(width), 3)
return image return image
cdf_plot = tf.compat.v1.py_func(cdf_plot, [values], tf.uint8) cdf_plot = tf.compat.v1.py_func(cdf_plot, [values], tf.uint8)
...@@ -725,8 +716,8 @@ def add_hist_image_summary(values, bins, name): ...@@ -725,8 +716,8 @@ def add_hist_image_summary(values, bins, name):
fig.canvas.draw() fig.canvas.draw()
width, height = fig.get_size_inches() * fig.get_dpi() width, height = fig.get_size_inches() * fig.get_dpi()
image = np.fromstring( image = np.fromstring(
fig.canvas.tostring_rgb(), dtype='uint8').reshape( fig.canvas.tostring_rgb(),
1, int(height), int(width), 3) dtype='uint8').reshape(1, int(height), int(width), 3)
return image return image
hist_plot = tf.compat.v1.py_func(hist_plot, [values, bins], tf.uint8) hist_plot = tf.compat.v1.py_func(hist_plot, [values, bins], tf.uint8)
......
...@@ -24,6 +24,7 @@ from __future__ import division ...@@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function from __future__ import print_function
import math import math
import tensorflow as tf import tensorflow as tf
from typing import Any, Dict, List, Optional, Text, Tuple from typing import Any, Dict, List, Optional, Text, Tuple
...@@ -120,10 +121,8 @@ def _convert_translation_to_transform(translations: tf.Tensor) -> tf.Tensor: ...@@ -120,10 +121,8 @@ def _convert_translation_to_transform(translations: tf.Tensor) -> tf.Tensor:
) )
def _convert_angles_to_transform( def _convert_angles_to_transform(angles: tf.Tensor, image_width: tf.Tensor,
angles: tf.Tensor, image_height: tf.Tensor) -> tf.Tensor:
image_width: tf.Tensor,
image_height: tf.Tensor) -> tf.Tensor:
"""Converts an angle or angles to a projective transform. """Converts an angle or angles to a projective transform.
Args: Args:
...@@ -173,9 +172,7 @@ def transform(image: tf.Tensor, transforms) -> tf.Tensor: ...@@ -173,9 +172,7 @@ def transform(image: tf.Tensor, transforms) -> tf.Tensor:
transforms = transforms[None] transforms = transforms[None]
image = to_4d(image) image = to_4d(image)
image = image_ops.transform( image = image_ops.transform(
images=image, images=image, transforms=transforms, interpolation='nearest')
transforms=transforms,
interpolation='nearest')
return from_4d(image, original_ndims) return from_4d(image, original_ndims)
...@@ -216,9 +213,8 @@ def rotate(image: tf.Tensor, degrees: float) -> tf.Tensor: ...@@ -216,9 +213,8 @@ def rotate(image: tf.Tensor, degrees: float) -> tf.Tensor:
image_height = tf.cast(tf.shape(image)[1], tf.float32) image_height = tf.cast(tf.shape(image)[1], tf.float32)
image_width = tf.cast(tf.shape(image)[2], tf.float32) image_width = tf.cast(tf.shape(image)[2], tf.float32)
transforms = _convert_angles_to_transform(angles=radians, transforms = _convert_angles_to_transform(
image_width=image_width, angles=radians, image_width=image_width, image_height=image_height)
image_height=image_height)
# In practice, we should randomize the rotation degrees by flipping # In practice, we should randomize the rotation degrees by flipping
# it negatively half the time, but that's done on 'degrees' outside # it negatively half the time, but that's done on 'degrees' outside
# of the function. # of the function.
...@@ -279,11 +275,10 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor: ...@@ -279,11 +275,10 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
Args: Args:
image: An image Tensor of type uint8. image: An image Tensor of type uint8.
pad_size: Specifies how big the zero mask that will be generated is that pad_size: Specifies how big the zero mask that will be generated is that is
is applied to the image. The mask will be of size applied to the image. The mask will be of size (2*pad_size x 2*pad_size).
(2*pad_size x 2*pad_size). replace: What pixel value to fill in the image in the area that has the
replace: What pixel value to fill in the image in the area that has cutout mask applied to it.
the cutout mask applied to it.
Returns: Returns:
An image Tensor that is of type uint8. An image Tensor that is of type uint8.
...@@ -293,30 +288,30 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor: ...@@ -293,30 +288,30 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
# Sample the center location in the image where the zero mask will be applied. # Sample the center location in the image where the zero mask will be applied.
cutout_center_height = tf.random.uniform( cutout_center_height = tf.random.uniform(
shape=[], minval=0, maxval=image_height, shape=[], minval=0, maxval=image_height, dtype=tf.int32)
dtype=tf.int32)
cutout_center_width = tf.random.uniform( cutout_center_width = tf.random.uniform(
shape=[], minval=0, maxval=image_width, shape=[], minval=0, maxval=image_width, dtype=tf.int32)
dtype=tf.int32)
lower_pad = tf.maximum(0, cutout_center_height - pad_size) lower_pad = tf.maximum(0, cutout_center_height - pad_size)
upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size) upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size)
left_pad = tf.maximum(0, cutout_center_width - pad_size) left_pad = tf.maximum(0, cutout_center_width - pad_size)
right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size) right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size)
cutout_shape = [image_height - (lower_pad + upper_pad), cutout_shape = [
image_width - (left_pad + right_pad)] image_height - (lower_pad + upper_pad),
image_width - (left_pad + right_pad)
]
padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
mask = tf.pad( mask = tf.pad(
tf.zeros(cutout_shape, dtype=image.dtype), tf.zeros(cutout_shape, dtype=image.dtype),
padding_dims, constant_values=1) padding_dims,
constant_values=1)
mask = tf.expand_dims(mask, -1) mask = tf.expand_dims(mask, -1)
mask = tf.tile(mask, [1, 1, 3]) mask = tf.tile(mask, [1, 1, 3])
image = tf.where( image = tf.where(
tf.equal(mask, 0), tf.equal(mask, 0),
tf.ones_like(image, dtype=image.dtype) * replace, tf.ones_like(image, dtype=image.dtype) * replace, image)
image)
return image return image
...@@ -398,8 +393,8 @@ def shear_x(image: tf.Tensor, level: float, replace: int) -> tf.Tensor: ...@@ -398,8 +393,8 @@ def shear_x(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
# with a matrix form of: # with a matrix form of:
# [1 level # [1 level
# 0 1]. # 0 1].
image = transform(image=wrap(image), image = transform(
transforms=[1., level, 0., 0., 1., 0., 0., 0.]) image=wrap(image), transforms=[1., level, 0., 0., 1., 0., 0., 0.])
return unwrap(image, replace) return unwrap(image, replace)
...@@ -409,8 +404,8 @@ def shear_y(image: tf.Tensor, level: float, replace: int) -> tf.Tensor: ...@@ -409,8 +404,8 @@ def shear_y(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
# with a matrix form of: # with a matrix form of:
# [1 0 # [1 0
# level 1]. # level 1].
image = transform(image=wrap(image), image = transform(
transforms=[1., 0., 0., level, 1., 0., 0., 0.]) image=wrap(image), transforms=[1., 0., 0., level, 1., 0., 0., 0.])
return unwrap(image, replace) return unwrap(image, replace)
...@@ -460,9 +455,9 @@ def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor: ...@@ -460,9 +455,9 @@ def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor:
# Make image 4D for conv operation. # Make image 4D for conv operation.
image = tf.expand_dims(image, 0) image = tf.expand_dims(image, 0)
# SMOOTH PIL Kernel. # SMOOTH PIL Kernel.
kernel = tf.constant( kernel = tf.constant([[1, 1, 1], [1, 5, 1], [1, 1, 1]],
[[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, dtype=tf.float32,
shape=[3, 3, 1, 1]) / 13. shape=[3, 3, 1, 1]) / 13.
# Tile across channel dimension. # Tile across channel dimension.
kernel = tf.tile(kernel, [1, 1, 3, 1]) kernel = tf.tile(kernel, [1, 1, 3, 1])
strides = [1, 1, 1, 1] strides = [1, 1, 1, 1]
...@@ -484,6 +479,7 @@ def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor: ...@@ -484,6 +479,7 @@ def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor:
def equalize(image: tf.Tensor) -> tf.Tensor: def equalize(image: tf.Tensor) -> tf.Tensor:
"""Implements Equalize function from PIL using TF ops.""" """Implements Equalize function from PIL using TF ops."""
def scale_channel(im, c): def scale_channel(im, c):
"""Scale the data in the channel to implement equalize.""" """Scale the data in the channel to implement equalize."""
im = tf.cast(im[:, :, c], tf.int32) im = tf.cast(im[:, :, c], tf.int32)
...@@ -507,9 +503,9 @@ def equalize(image: tf.Tensor) -> tf.Tensor: ...@@ -507,9 +503,9 @@ def equalize(image: tf.Tensor) -> tf.Tensor:
# If step is zero, return the original image. Otherwise, build # If step is zero, return the original image. Otherwise, build
# lut from the full histogram and step and then index from it. # lut from the full histogram and step and then index from it.
result = tf.cond(tf.equal(step, 0), result = tf.cond(
lambda: im, tf.equal(step, 0), lambda: im,
lambda: tf.gather(build_lut(histo, step), im)) lambda: tf.gather(build_lut(histo, step), im))
return tf.cast(result, tf.uint8) return tf.cast(result, tf.uint8)
...@@ -582,7 +578,7 @@ def _randomly_negate_tensor(tensor): ...@@ -582,7 +578,7 @@ def _randomly_negate_tensor(tensor):
def _rotate_level_to_arg(level: float): def _rotate_level_to_arg(level: float):
level = (level/_MAX_LEVEL) * 30. level = (level / _MAX_LEVEL) * 30.
level = _randomly_negate_tensor(level) level = _randomly_negate_tensor(level)
return (level,) return (level,)
...@@ -597,18 +593,18 @@ def _shrink_level_to_arg(level: float): ...@@ -597,18 +593,18 @@ def _shrink_level_to_arg(level: float):
def _enhance_level_to_arg(level: float): def _enhance_level_to_arg(level: float):
return ((level/_MAX_LEVEL) * 1.8 + 0.1,) return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
def _shear_level_to_arg(level: float): def _shear_level_to_arg(level: float):
level = (level/_MAX_LEVEL) * 0.3 level = (level / _MAX_LEVEL) * 0.3
# Flip level to negative with 50% chance. # Flip level to negative with 50% chance.
level = _randomly_negate_tensor(level) level = _randomly_negate_tensor(level)
return (level,) return (level,)
def _translate_level_to_arg(level: float, translate_const: float): def _translate_level_to_arg(level: float, translate_const: float):
level = (level/_MAX_LEVEL) * float(translate_const) level = (level / _MAX_LEVEL) * float(translate_const)
# Flip level to negative with 50% chance. # Flip level to negative with 50% chance.
level = _randomly_negate_tensor(level) level = _randomly_negate_tensor(level)
return (level,) return (level,)
...@@ -618,20 +614,15 @@ def _mult_to_arg(level: float, multiplier: float = 1.): ...@@ -618,20 +614,15 @@ def _mult_to_arg(level: float, multiplier: float = 1.):
return (int((level / _MAX_LEVEL) * multiplier),) return (int((level / _MAX_LEVEL) * multiplier),)
def _apply_func_with_prob(func: Any, def _apply_func_with_prob(func: Any, image: tf.Tensor, args: Any, prob: float):
image: tf.Tensor,
args: Any,
prob: float):
"""Apply `func` to image w/ `args` as input with probability `prob`.""" """Apply `func` to image w/ `args` as input with probability `prob`."""
assert isinstance(args, tuple) assert isinstance(args, tuple)
# Apply the function with probability `prob`. # Apply the function with probability `prob`.
should_apply_op = tf.cast( should_apply_op = tf.cast(
tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool) tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool)
augmented_image = tf.cond( augmented_image = tf.cond(should_apply_op, lambda: func(image, *args),
should_apply_op, lambda: image)
lambda: func(image, *args),
lambda: image)
return augmented_image return augmented_image
...@@ -709,11 +700,8 @@ def level_to_arg(cutout_const: float, translate_const: float): ...@@ -709,11 +700,8 @@ def level_to_arg(cutout_const: float, translate_const: float):
return args return args
def _parse_policy_info(name: Text, def _parse_policy_info(name: Text, prob: float, level: float,
prob: float, replace_value: List[int], cutout_const: float,
level: float,
replace_value: List[int],
cutout_const: float,
translate_const: float) -> Tuple[Any, float, Any]: translate_const: float) -> Tuple[Any, float, Any]:
"""Return the function that corresponds to `name` and update `level` param.""" """Return the function that corresponds to `name` and update `level` param."""
func = NAME_TO_FUNC[name] func = NAME_TO_FUNC[name]
...@@ -969,8 +957,9 @@ class RandAugment(ImageAugment): ...@@ -969,8 +957,9 @@ class RandAugment(ImageAugment):
min_prob, max_prob = 0.2, 0.8 min_prob, max_prob = 0.2, 0.8
for _ in range(self.num_layers): for _ in range(self.num_layers):
op_to_select = tf.random.uniform( op_to_select = tf.random.uniform([],
[], maxval=len(self.available_ops) + 1, dtype=tf.int32) maxval=len(self.available_ops) + 1,
dtype=tf.int32)
branch_fns = [] branch_fns = []
for (i, op_name) in enumerate(self.available_ops): for (i, op_name) in enumerate(self.available_ops):
...@@ -978,11 +967,8 @@ class RandAugment(ImageAugment): ...@@ -978,11 +967,8 @@ class RandAugment(ImageAugment):
minval=min_prob, minval=min_prob,
maxval=max_prob, maxval=max_prob,
dtype=tf.float32) dtype=tf.float32)
func, _, args = _parse_policy_info(op_name, func, _, args = _parse_policy_info(op_name, prob, self.magnitude,
prob, replace_value, self.cutout_const,
self.magnitude,
replace_value,
self.cutout_const,
self.translate_const) self.translate_const)
branch_fns.append(( branch_fns.append((
i, i,
...@@ -991,9 +977,10 @@ class RandAugment(ImageAugment): ...@@ -991,9 +977,10 @@ class RandAugment(ImageAugment):
image, *selected_args))) image, *selected_args)))
# pylint:enable=g-long-lambda # pylint:enable=g-long-lambda
image = tf.switch_case(branch_index=op_to_select, image = tf.switch_case(
branch_fns=branch_fns, branch_index=op_to_select,
default=lambda: tf.identity(image)) branch_fns=branch_fns,
default=lambda: tf.identity(image))
image = tf.cast(image, dtype=input_image_type) image = tf.cast(image, dtype=input_image_type)
return image return image
...@@ -49,24 +49,15 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase): ...@@ -49,24 +49,15 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
def test_transform(self, dtype): def test_transform(self, dtype):
image = tf.constant([[1, 2], [3, 4]], dtype=dtype) image = tf.constant([[1, 2], [3, 4]], dtype=dtype)
self.assertAllEqual(augment.transform(image, transforms=[1]*8), self.assertAllEqual(
[[4, 4], [4, 4]]) augment.transform(image, transforms=[1] * 8), [[4, 4], [4, 4]])
def test_translate(self, dtype): def test_translate(self, dtype):
image = tf.constant( image = tf.constant(
[[1, 0, 1, 0], [[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], dtype=dtype)
[0, 1, 0, 1],
[1, 0, 1, 0],
[0, 1, 0, 1]],
dtype=dtype)
translations = [-1, -1] translations = [-1, -1]
translated = augment.translate(image=image, translated = augment.translate(image=image, translations=translations)
translations=translations) expected = [[1, 0, 1, 1], [0, 1, 0, 0], [1, 0, 1, 1], [1, 0, 1, 1]]
expected = [
[1, 0, 1, 1],
[0, 1, 0, 0],
[1, 0, 1, 1],
[1, 0, 1, 1]]
self.assertAllEqual(translated, expected) self.assertAllEqual(translated, expected)
def test_translate_shapes(self, dtype): def test_translate_shapes(self, dtype):
...@@ -85,9 +76,7 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase): ...@@ -85,9 +76,7 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
image = tf.reshape(tf.cast(tf.range(9), dtype), (3, 3)) image = tf.reshape(tf.cast(tf.range(9), dtype), (3, 3))
rotation = 90. rotation = 90.
transformed = augment.rotate(image=image, degrees=rotation) transformed = augment.rotate(image=image, degrees=rotation)
expected = [[2, 5, 8], expected = [[2, 5, 8], [1, 4, 7], [0, 3, 6]]
[1, 4, 7],
[0, 3, 6]]
self.assertAllEqual(transformed, expected) self.assertAllEqual(transformed, expected)
def test_rotate_shapes(self, dtype): def test_rotate_shapes(self, dtype):
...@@ -129,15 +118,13 @@ class AutoaugmentTest(tf.test.TestCase): ...@@ -129,15 +118,13 @@ class AutoaugmentTest(tf.test.TestCase):
image = tf.ones((224, 224, 3), dtype=tf.uint8) image = tf.ones((224, 224, 3), dtype=tf.uint8)
for op_name in augment.NAME_TO_FUNC: for op_name in augment.NAME_TO_FUNC:
func, _, args = augment._parse_policy_info(op_name, func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
prob, replace_value, cutout_const,
magnitude,
replace_value,
cutout_const,
translate_const) translate_const)
image = func(image, *args) image = func(image, *args)
self.assertEqual((224, 224, 3), image.shape) self.assertEqual((224, 224, 3), image.shape)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -21,6 +21,7 @@ from __future__ import print_function ...@@ -21,6 +21,7 @@ from __future__ import print_function
import os import os
from typing import Any, List, MutableMapping, Text from typing import Any, List, MutableMapping, Text
from absl import logging from absl import logging
import tensorflow as tf import tensorflow as tf
...@@ -43,8 +44,9 @@ def get_callbacks(model_checkpoint: bool = True, ...@@ -43,8 +44,9 @@ def get_callbacks(model_checkpoint: bool = True,
callbacks = [] callbacks = []
if model_checkpoint: if model_checkpoint:
ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}') ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
callbacks.append(tf.keras.callbacks.ModelCheckpoint( callbacks.append(
ckpt_full_path, save_weights_only=True, verbose=1)) tf.keras.callbacks.ModelCheckpoint(
ckpt_full_path, save_weights_only=True, verbose=1))
if include_tensorboard: if include_tensorboard:
callbacks.append( callbacks.append(
CustomTensorBoard( CustomTensorBoard(
...@@ -61,13 +63,14 @@ def get_callbacks(model_checkpoint: bool = True, ...@@ -61,13 +63,14 @@ def get_callbacks(model_checkpoint: bool = True,
if apply_moving_average: if apply_moving_average:
# Save moving average model to a different file so that # Save moving average model to a different file so that
# we can resume training from a checkpoint # we can resume training from a checkpoint
ckpt_full_path = os.path.join( ckpt_full_path = os.path.join(model_dir, 'average',
model_dir, 'average', 'model.ckpt-{epoch:04d}') 'model.ckpt-{epoch:04d}')
callbacks.append(AverageModelCheckpoint( callbacks.append(
update_weights=False, AverageModelCheckpoint(
filepath=ckpt_full_path, update_weights=False,
save_weights_only=True, filepath=ckpt_full_path,
verbose=1)) save_weights_only=True,
verbose=1))
callbacks.append(MovingAverageCallback()) callbacks.append(MovingAverageCallback())
return callbacks return callbacks
...@@ -175,16 +178,13 @@ class MovingAverageCallback(tf.keras.callbacks.Callback): ...@@ -175,16 +178,13 @@ class MovingAverageCallback(tf.keras.callbacks.Callback):
**kwargs: Any additional callback arguments. **kwargs: Any additional callback arguments.
""" """
def __init__(self, def __init__(self, overwrite_weights_on_train_end: bool = False, **kwargs):
overwrite_weights_on_train_end: bool = False,
**kwargs):
super(MovingAverageCallback, self).__init__(**kwargs) super(MovingAverageCallback, self).__init__(**kwargs)
self.overwrite_weights_on_train_end = overwrite_weights_on_train_end self.overwrite_weights_on_train_end = overwrite_weights_on_train_end
def set_model(self, model: tf.keras.Model): def set_model(self, model: tf.keras.Model):
super(MovingAverageCallback, self).set_model(model) super(MovingAverageCallback, self).set_model(model)
assert isinstance(self.model.optimizer, assert isinstance(self.model.optimizer, optimizer_factory.MovingAverage)
optimizer_factory.MovingAverage)
self.model.optimizer.shadow_copy(self.model) self.model.optimizer.shadow_copy(self.model)
def on_test_begin(self, logs: MutableMapping[Text, Any] = None): def on_test_begin(self, logs: MutableMapping[Text, Any] = None):
...@@ -204,40 +204,30 @@ class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint): ...@@ -204,40 +204,30 @@ class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
Taken from tfa.callbacks.AverageModelCheckpoint. Taken from tfa.callbacks.AverageModelCheckpoint.
Attributes: Attributes:
update_weights: If True, assign the moving average weights update_weights: If True, assign the moving average weights to the model, and
to the model, and save them. If False, keep the old save them. If False, keep the old non-averaged weights, but the saved
non-averaged weights, but the saved model uses the model uses the average weights. See `tf.keras.callbacks.ModelCheckpoint`
average weights. for the other args.
See `tf.keras.callbacks.ModelCheckpoint` for the other args.
""" """
def __init__( def __init__(self,
self, update_weights: bool,
update_weights: bool, filepath: str,
filepath: str, monitor: str = 'val_loss',
monitor: str = 'val_loss', verbose: int = 0,
verbose: int = 0, save_best_only: bool = False,
save_best_only: bool = False, save_weights_only: bool = False,
save_weights_only: bool = False, mode: str = 'auto',
mode: str = 'auto', save_freq: str = 'epoch',
save_freq: str = 'epoch', **kwargs):
**kwargs):
self.update_weights = update_weights self.update_weights = update_weights
super().__init__( super().__init__(filepath, monitor, verbose, save_best_only,
filepath, save_weights_only, mode, save_freq, **kwargs)
monitor,
verbose,
save_best_only,
save_weights_only,
mode,
save_freq,
**kwargs)
def set_model(self, model): def set_model(self, model):
if not isinstance(model.optimizer, optimizer_factory.MovingAverage): if not isinstance(model.optimizer, optimizer_factory.MovingAverage):
raise TypeError( raise TypeError('AverageModelCheckpoint is only used when training'
'AverageModelCheckpoint is only used when training' 'with MovingAverage')
'with MovingAverage')
return super().set_model(model) return super().set_model(model)
def _save_model(self, epoch, logs): def _save_model(self, epoch, logs):
......
...@@ -41,7 +41,7 @@ from official.vision.image_classification.resnet import resnet_model ...@@ -41,7 +41,7 @@ from official.vision.image_classification.resnet import resnet_model
def get_models() -> Mapping[str, tf.keras.Model]: def get_models() -> Mapping[str, tf.keras.Model]:
"""Returns the mapping from model type name to Keras model.""" """Returns the mapping from model type name to Keras model."""
return { return {
'efficientnet': efficientnet_model.EfficientNet.from_name, 'efficientnet': efficientnet_model.EfficientNet.from_name,
'resnet': resnet_model.resnet50, 'resnet': resnet_model.resnet50,
} }
...@@ -55,7 +55,7 @@ def get_dtype_map() -> Mapping[str, tf.dtypes.DType]: ...@@ -55,7 +55,7 @@ def get_dtype_map() -> Mapping[str, tf.dtypes.DType]:
'float16': tf.float16, 'float16': tf.float16,
'fp32': tf.float32, 'fp32': tf.float32,
'bf16': tf.bfloat16, 'bf16': tf.bfloat16,
} }
def _get_metrics(one_hot: bool) -> Mapping[Text, Any]: def _get_metrics(one_hot: bool) -> Mapping[Text, Any]:
...@@ -63,22 +63,28 @@ def _get_metrics(one_hot: bool) -> Mapping[Text, Any]: ...@@ -63,22 +63,28 @@ def _get_metrics(one_hot: bool) -> Mapping[Text, Any]:
if one_hot: if one_hot:
return { return {
# (name, metric_fn) # (name, metric_fn)
'acc': tf.keras.metrics.CategoricalAccuracy(name='accuracy'), 'acc':
'accuracy': tf.keras.metrics.CategoricalAccuracy(name='accuracy'), tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
'top_1': tf.keras.metrics.CategoricalAccuracy(name='accuracy'), 'accuracy':
'top_5': tf.keras.metrics.TopKCategoricalAccuracy( tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
k=5, 'top_1':
name='top_5_accuracy'), tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
'top_5':
tf.keras.metrics.TopKCategoricalAccuracy(
k=5, name='top_5_accuracy'),
} }
else: else:
return { return {
# (name, metric_fn) # (name, metric_fn)
'acc': tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'), 'acc':
'accuracy': tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'), tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
'top_1': tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'), 'accuracy':
'top_5': tf.keras.metrics.SparseTopKCategoricalAccuracy( tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
k=5, 'top_1':
name='top_5_accuracy'), tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
'top_5':
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=5, name='top_5_accuracy'),
} }
...@@ -94,8 +100,7 @@ def get_image_size_from_model( ...@@ -94,8 +100,7 @@ def get_image_size_from_model(
def _get_dataset_builders(params: base_configs.ExperimentConfig, def _get_dataset_builders(params: base_configs.ExperimentConfig,
strategy: tf.distribute.Strategy, strategy: tf.distribute.Strategy,
one_hot: bool one_hot: bool) -> Tuple[Any, Any]:
) -> Tuple[Any, Any]:
"""Create and return train and validation dataset builders.""" """Create and return train and validation dataset builders."""
if one_hot: if one_hot:
logging.warning('label_smoothing > 0, so datasets will be one hot encoded.') logging.warning('label_smoothing > 0, so datasets will be one hot encoded.')
...@@ -107,9 +112,7 @@ def _get_dataset_builders(params: base_configs.ExperimentConfig, ...@@ -107,9 +112,7 @@ def _get_dataset_builders(params: base_configs.ExperimentConfig,
image_size = get_image_size_from_model(params) image_size = get_image_size_from_model(params)
dataset_configs = [ dataset_configs = [params.train_dataset, params.validation_dataset]
params.train_dataset, params.validation_dataset
]
builders = [] builders = []
for config in dataset_configs: for config in dataset_configs:
...@@ -171,8 +174,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues): ...@@ -171,8 +174,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
}, },
} }
overriding_configs = (flags_obj.config_file, overriding_configs = (flags_obj.config_file, flags_obj.params_override,
flags_obj.params_override,
flags_overrides) flags_overrides)
pp = pprint.PrettyPrinter() pp = pprint.PrettyPrinter()
...@@ -190,8 +192,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues): ...@@ -190,8 +192,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
return params return params
def resume_from_checkpoint(model: tf.keras.Model, def resume_from_checkpoint(model: tf.keras.Model, model_dir: str,
model_dir: str,
train_steps: int) -> int: train_steps: int) -> int:
"""Resumes from the latest checkpoint, if possible. """Resumes from the latest checkpoint, if possible.
...@@ -226,8 +227,7 @@ def resume_from_checkpoint(model: tf.keras.Model, ...@@ -226,8 +227,7 @@ def resume_from_checkpoint(model: tf.keras.Model,
def initialize(params: base_configs.ExperimentConfig, def initialize(params: base_configs.ExperimentConfig,
dataset_builder: dataset_factory.DatasetBuilder): dataset_builder: dataset_factory.DatasetBuilder):
"""Initializes backend related initializations.""" """Initializes backend related initializations."""
keras_utils.set_session_config( keras_utils.set_session_config(enable_xla=params.runtime.enable_xla)
enable_xla=params.runtime.enable_xla)
performance.set_mixed_precision_policy(dataset_builder.dtype, performance.set_mixed_precision_policy(dataset_builder.dtype,
get_loss_scale(params)) get_loss_scale(params))
if tf.config.list_physical_devices('GPU'): if tf.config.list_physical_devices('GPU'):
...@@ -244,7 +244,8 @@ def initialize(params: base_configs.ExperimentConfig, ...@@ -244,7 +244,8 @@ def initialize(params: base_configs.ExperimentConfig,
per_gpu_thread_count=params.runtime.per_gpu_thread_count, per_gpu_thread_count=params.runtime.per_gpu_thread_count,
gpu_thread_mode=params.runtime.gpu_thread_mode, gpu_thread_mode=params.runtime.gpu_thread_mode,
num_gpus=params.runtime.num_gpus, num_gpus=params.runtime.num_gpus,
datasets_num_private_threads=params.runtime.dataset_num_private_threads) # pylint:disable=line-too-long datasets_num_private_threads=params.runtime
.dataset_num_private_threads) # pylint:disable=line-too-long
if params.runtime.batchnorm_spatial_persistent: if params.runtime.batchnorm_spatial_persistent:
os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1' os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
...@@ -253,9 +254,7 @@ def define_classifier_flags(): ...@@ -253,9 +254,7 @@ def define_classifier_flags():
"""Defines common flags for image classification.""" """Defines common flags for image classification."""
hyperparams_flags.initialize_common_flags() hyperparams_flags.initialize_common_flags()
flags.DEFINE_string( flags.DEFINE_string(
'data_dir', 'data_dir', default=None, help='The location of the input data.')
default=None,
help='The location of the input data.')
flags.DEFINE_string( flags.DEFINE_string(
'mode', 'mode',
default=None, default=None,
...@@ -278,8 +277,7 @@ def define_classifier_flags(): ...@@ -278,8 +277,7 @@ def define_classifier_flags():
help='The interval of steps between logging of batch level stats.') help='The interval of steps between logging of batch level stats.')
def serialize_config(params: base_configs.ExperimentConfig, def serialize_config(params: base_configs.ExperimentConfig, model_dir: str):
model_dir: str):
"""Serializes and saves the experiment config.""" """Serializes and saves the experiment config."""
params_save_path = os.path.join(model_dir, 'params.yaml') params_save_path = os.path.join(model_dir, 'params.yaml')
logging.info('Saving experiment configuration to %s', params_save_path) logging.info('Saving experiment configuration to %s', params_save_path)
...@@ -293,9 +291,8 @@ def train_and_eval( ...@@ -293,9 +291,8 @@ def train_and_eval(
"""Runs the train and eval path using compile/fit.""" """Runs the train and eval path using compile/fit."""
logging.info('Running train and eval.') logging.info('Running train and eval.')
distribution_utils.configure_cluster( distribution_utils.configure_cluster(params.runtime.worker_hosts,
params.runtime.worker_hosts, params.runtime.task_index)
params.runtime.task_index)
# Note: for TPUs, strategy and scope should be created before the dataset # Note: for TPUs, strategy and scope should be created before the dataset
strategy = strategy_override or distribution_utils.get_distribution_strategy( strategy = strategy_override or distribution_utils.get_distribution_strategy(
...@@ -313,8 +310,9 @@ def train_and_eval( ...@@ -313,8 +310,9 @@ def train_and_eval(
one_hot = label_smoothing and label_smoothing > 0 one_hot = label_smoothing and label_smoothing > 0
builders = _get_dataset_builders(params, strategy, one_hot) builders = _get_dataset_builders(params, strategy, one_hot)
datasets = [builder.build(strategy) datasets = [
if builder else None for builder in builders] builder.build(strategy) if builder else None for builder in builders
]
# Unpack datasets and builders based on train/val/test splits # Unpack datasets and builders based on train/val/test splits
train_builder, validation_builder = builders # pylint: disable=unbalanced-tuple-unpacking train_builder, validation_builder = builders # pylint: disable=unbalanced-tuple-unpacking
...@@ -351,16 +349,16 @@ def train_and_eval( ...@@ -351,16 +349,16 @@ def train_and_eval(
label_smoothing=params.model.loss.label_smoothing) label_smoothing=params.model.loss.label_smoothing)
else: else:
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy() loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
model.compile(optimizer=optimizer, model.compile(
loss=loss_obj, optimizer=optimizer,
metrics=metrics, loss=loss_obj,
experimental_steps_per_execution=steps_per_loop) metrics=metrics,
experimental_steps_per_execution=steps_per_loop)
initial_epoch = 0 initial_epoch = 0
if params.train.resume_checkpoint: if params.train.resume_checkpoint:
initial_epoch = resume_from_checkpoint(model=model, initial_epoch = resume_from_checkpoint(
model_dir=params.model_dir, model=model, model_dir=params.model_dir, train_steps=train_steps)
train_steps=train_steps)
callbacks = custom_callbacks.get_callbacks( callbacks = custom_callbacks.get_callbacks(
model_checkpoint=params.train.callbacks.enable_checkpoint_and_export, model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
...@@ -399,9 +397,7 @@ def train_and_eval( ...@@ -399,9 +397,7 @@ def train_and_eval(
validation_dataset, steps=validation_steps, verbose=2) validation_dataset, steps=validation_steps, verbose=2)
# TODO(dankondratyuk): eval and save final test accuracy # TODO(dankondratyuk): eval and save final test accuracy
stats = common.build_stats(history, stats = common.build_stats(history, validation_output, callbacks)
validation_output,
callbacks)
return stats return stats
......
...@@ -105,14 +105,13 @@ def get_trivial_model(num_classes: int) -> tf.keras.Model: ...@@ -105,14 +105,13 @@ def get_trivial_model(num_classes: int) -> tf.keras.Model:
lr = 0.01 lr = 0.01
optimizer = tf.keras.optimizers.SGD(learning_rate=lr) optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy() loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
model.compile(optimizer=optimizer, model.compile(optimizer=optimizer, loss=loss_obj, run_eagerly=True)
loss=loss_obj,
run_eagerly=True)
return model return model
def get_trivial_data() -> tf.data.Dataset: def get_trivial_data() -> tf.data.Dataset:
"""Gets trivial data in the ImageNet size.""" """Gets trivial data in the ImageNet size."""
def generate_data(_) -> tf.data.Dataset: def generate_data(_) -> tf.data.Dataset:
image = tf.zeros(shape=(224, 224, 3), dtype=tf.float32) image = tf.zeros(shape=(224, 224, 3), dtype=tf.float32)
label = tf.zeros([1], dtype=tf.int32) label = tf.zeros([1], dtype=tf.int32)
...@@ -120,8 +119,8 @@ def get_trivial_data() -> tf.data.Dataset: ...@@ -120,8 +119,8 @@ def get_trivial_data() -> tf.data.Dataset:
dataset = tf.data.Dataset.range(1) dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat() dataset = dataset.repeat()
dataset = dataset.map(generate_data, dataset = dataset.map(
num_parallel_calls=tf.data.experimental.AUTOTUNE) generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(buffer_size=1).batch(1) dataset = dataset.prefetch(buffer_size=1).batch(1)
return dataset return dataset
...@@ -165,11 +164,10 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase): ...@@ -165,11 +164,10 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
'--mode=train_and_eval', '--mode=train_and_eval',
] ]
run = functools.partial(classifier_trainer.run, run = functools.partial(
strategy_override=distribution) classifier_trainer.run, strategy_override=distribution)
run_end_to_end(main=run, run_end_to_end(
extra_flags=train_and_eval_flags, main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)
model_dir=model_dir)
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
...@@ -209,29 +207,26 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase): ...@@ -209,29 +207,26 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
get_params_override(export_params) get_params_override(export_params)
] ]
run = functools.partial(classifier_trainer.run, run = functools.partial(
strategy_override=distribution) classifier_trainer.run, strategy_override=distribution)
run_end_to_end(main=run, run_end_to_end(
extra_flags=train_and_eval_flags, main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)
model_dir=model_dir) run_end_to_end(main=run, extra_flags=export_flags, model_dir=model_dir)
run_end_to_end(main=run,
extra_flags=export_flags,
model_dir=model_dir)
self.assertTrue(os.path.exists(export_path)) self.assertTrue(os.path.exists(export_path))
@combinations.generate( @combinations.generate(
combinations.combine( combinations.combine(
distribution=[ distribution=[
strategy_combinations.tpu_strategy, strategy_combinations.tpu_strategy,
], ],
model=[ model=[
'efficientnet', 'efficientnet',
'resnet', 'resnet',
], ],
mode='eager', mode='eager',
dataset='imagenet', dataset='imagenet',
dtype='bfloat16', dtype='bfloat16',
)) ))
def test_tpu_train(self, distribution, model, dataset, dtype): def test_tpu_train(self, distribution, model, dataset, dtype):
"""Test train_and_eval and export for Keras classifier models.""" """Test train_and_eval and export for Keras classifier models."""
# Some parameters are not defined as flags (e.g. cannot run # Some parameters are not defined as flags (e.g. cannot run
...@@ -248,11 +243,10 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase): ...@@ -248,11 +243,10 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
'--mode=train_and_eval', '--mode=train_and_eval',
] ]
run = functools.partial(classifier_trainer.run, run = functools.partial(
strategy_override=distribution) classifier_trainer.run, strategy_override=distribution)
run_end_to_end(main=run, run_end_to_end(
extra_flags=train_and_eval_flags, main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)
model_dir=model_dir)
@combinations.generate(distribution_strategy_combinations()) @combinations.generate(distribution_strategy_combinations())
def test_end_to_end_invalid_mode(self, distribution, model, dataset): def test_end_to_end_invalid_mode(self, distribution, model, dataset):
...@@ -266,8 +260,8 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase): ...@@ -266,8 +260,8 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
get_params_override(basic_params_override()), get_params_override(basic_params_override()),
] ]
run = functools.partial(classifier_trainer.run, run = functools.partial(
strategy_override=distribution) classifier_trainer.run, strategy_override=distribution)
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
run_end_to_end(main=run, extra_flags=extra_flags, model_dir=model_dir) run_end_to_end(main=run, extra_flags=extra_flags, model_dir=model_dir)
...@@ -292,9 +286,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase): ...@@ -292,9 +286,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
model=base_configs.ModelConfig( model=base_configs.ModelConfig(
model_params={ model_params={
'model_name': model_name, 'model_name': model_name,
}, },))
)
)
size = classifier_trainer.get_image_size_from_model(config) size = classifier_trainer.get_image_size_from_model(config)
self.assertEqual(size, expected) self.assertEqual(size, expected)
...@@ -306,16 +298,13 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase): ...@@ -306,16 +298,13 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
) )
def test_get_loss_scale(self, loss_scale, dtype, expected): def test_get_loss_scale(self, loss_scale, dtype, expected):
config = base_configs.ExperimentConfig( config = base_configs.ExperimentConfig(
runtime=base_configs.RuntimeConfig( runtime=base_configs.RuntimeConfig(loss_scale=loss_scale),
loss_scale=loss_scale),
train_dataset=dataset_factory.DatasetConfig(dtype=dtype)) train_dataset=dataset_factory.DatasetConfig(dtype=dtype))
ls = classifier_trainer.get_loss_scale(config, fp16_default=128) ls = classifier_trainer.get_loss_scale(config, fp16_default=128)
self.assertEqual(ls, expected) self.assertEqual(ls, expected)
@parameterized.named_parameters( @parameterized.named_parameters(('float16', 'float16'),
('float16', 'float16'), ('bfloat16', 'bfloat16'))
('bfloat16', 'bfloat16')
)
def test_initialize(self, dtype): def test_initialize(self, dtype):
config = base_configs.ExperimentConfig( config = base_configs.ExperimentConfig(
runtime=base_configs.RuntimeConfig( runtime=base_configs.RuntimeConfig(
...@@ -332,6 +321,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase): ...@@ -332,6 +321,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
class EmptyClass: class EmptyClass:
pass pass
fake_ds_builder = EmptyClass() fake_ds_builder = EmptyClass()
fake_ds_builder.dtype = dtype fake_ds_builder.dtype = dtype
fake_ds_builder.config = EmptyClass() fake_ds_builder.config = EmptyClass()
...@@ -366,9 +356,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase): ...@@ -366,9 +356,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
clean_model = get_trivial_model(10) clean_model = get_trivial_model(10)
weights_before_load = copy.deepcopy(clean_model.get_weights()) weights_before_load = copy.deepcopy(clean_model.get_weights())
initial_epoch = classifier_trainer.resume_from_checkpoint( initial_epoch = classifier_trainer.resume_from_checkpoint(
model=clean_model, model=clean_model, model_dir=model_dir, train_steps=train_steps)
model_dir=model_dir,
train_steps=train_steps)
self.assertEqual(initial_epoch, 1) self.assertEqual(initial_epoch, 1)
self.assertNotAllClose(weights_before_load, clean_model.get_weights()) self.assertNotAllClose(weights_before_load, clean_model.get_weights())
...@@ -383,5 +371,6 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase): ...@@ -383,5 +371,6 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
self.assertTrue(os.path.exists(saved_params_path)) self.assertTrue(os.path.exists(saved_params_path))
tf.io.gfile.rmtree(model_dir) tf.io.gfile.rmtree(model_dir)
if __name__ == '__main__': if __name__ == '__main__':
tf.test.main() tf.test.main()
...@@ -18,7 +18,6 @@ from __future__ import absolute_import ...@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
from typing import Any, List, Mapping, Optional from typing import Any, List, Mapping, Optional
import dataclasses import dataclasses
......
...@@ -37,7 +37,6 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig): ...@@ -37,7 +37,6 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
train: A `TrainConfig` instance. train: A `TrainConfig` instance.
evaluation: An `EvalConfig` instance. evaluation: An `EvalConfig` instance.
model: A `ModelConfig` instance. model: A `ModelConfig` instance.
""" """
export: base_configs.ExportConfig = base_configs.ExportConfig() export: base_configs.ExportConfig = base_configs.ExportConfig()
runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig() runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
...@@ -49,16 +48,15 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig): ...@@ -49,16 +48,15 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
resume_checkpoint=True, resume_checkpoint=True,
epochs=500, epochs=500,
steps=None, steps=None,
callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True, callbacks=base_configs.CallbacksConfig(
enable_tensorboard=True), enable_checkpoint_and_export=True, enable_tensorboard=True),
metrics=['accuracy', 'top_5'], metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100), time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorboardConfig(track_lr=True, tensorboard=base_configs.TensorboardConfig(
write_model_weights=False), track_lr=True, write_model_weights=False),
set_epoch_loop=False) set_epoch_loop=False)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig( evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1, epochs_between_evals=1, steps=None)
steps=None)
model: base_configs.ModelConfig = \ model: base_configs.ModelConfig = \
efficientnet_config.EfficientNetModelConfig() efficientnet_config.EfficientNetModelConfig()
...@@ -82,16 +80,15 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig): ...@@ -82,16 +80,15 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
resume_checkpoint=True, resume_checkpoint=True,
epochs=90, epochs=90,
steps=None, steps=None,
callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True, callbacks=base_configs.CallbacksConfig(
enable_tensorboard=True), enable_checkpoint_and_export=True, enable_tensorboard=True),
metrics=['accuracy', 'top_5'], metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100), time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorboardConfig(track_lr=True, tensorboard=base_configs.TensorboardConfig(
write_model_weights=False), track_lr=True, write_model_weights=False),
set_epoch_loop=False) set_epoch_loop=False)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig( evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1, epochs_between_evals=1, steps=None)
steps=None)
model: base_configs.ModelConfig = resnet_config.ResNetModelConfig() model: base_configs.ModelConfig = resnet_config.ResNetModelConfig()
...@@ -109,10 +106,8 @@ def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig: ...@@ -109,10 +106,8 @@ def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
if dataset not in dataset_model_config_map: if dataset not in dataset_model_config_map:
raise KeyError('Invalid dataset received. Received: {}. Supported ' raise KeyError('Invalid dataset received. Received: {}. Supported '
'datasets include: {}'.format( 'datasets include: {}'.format(
dataset, dataset, ', '.join(dataset_model_config_map.keys())))
', '.join(dataset_model_config_map.keys())))
raise KeyError('Invalid model received. Received: {}. Supported models for' raise KeyError('Invalid model received. Received: {}. Supported models for'
'{} include: {}'.format( '{} include: {}'.format(
model, model, dataset,
dataset,
', '.join(dataset_model_config_map[dataset].keys()))) ', '.join(dataset_model_config_map[dataset].keys())))
...@@ -21,6 +21,7 @@ from __future__ import print_function ...@@ -21,6 +21,7 @@ from __future__ import print_function
import os import os
from typing import Any, List, Optional, Tuple, Mapping, Union from typing import Any, List, Optional, Tuple, Mapping, Union
from absl import logging from absl import logging
from dataclasses import dataclass from dataclasses import dataclass
import tensorflow as tf import tensorflow as tf
...@@ -30,7 +31,6 @@ from official.modeling.hyperparams import base_config ...@@ -30,7 +31,6 @@ from official.modeling.hyperparams import base_config
from official.vision.image_classification import augment from official.vision.image_classification import augment
from official.vision.image_classification import preprocessing from official.vision.image_classification import preprocessing
AUGMENTERS = { AUGMENTERS = {
'autoaugment': augment.AutoAugment, 'autoaugment': augment.AutoAugment,
'randaugment': augment.RandAugment, 'randaugment': augment.RandAugment,
...@@ -42,8 +42,8 @@ class AugmentConfig(base_config.Config): ...@@ -42,8 +42,8 @@ class AugmentConfig(base_config.Config):
"""Configuration for image augmenters. """Configuration for image augmenters.
Attributes: Attributes:
name: The name of the image augmentation to use. Possible options are name: The name of the image augmentation to use. Possible options are None
None (default), 'autoaugment', or 'randaugment'. (default), 'autoaugment', or 'randaugment'.
params: Any paramaters used to initialize the augmenter. params: Any paramaters used to initialize the augmenter.
""" """
name: Optional[str] = None name: Optional[str] = None
...@@ -68,17 +68,17 @@ class DatasetConfig(base_config.Config): ...@@ -68,17 +68,17 @@ class DatasetConfig(base_config.Config):
'tfds' (load using TFDS), 'records' (load from TFRecords), or 'synthetic' 'tfds' (load using TFDS), 'records' (load from TFRecords), or 'synthetic'
(generate dummy synthetic data without reading from files). (generate dummy synthetic data without reading from files).
split: The split of the dataset. Usually 'train', 'validation', or 'test'. split: The split of the dataset. Usually 'train', 'validation', or 'test'.
image_size: The size of the image in the dataset. This assumes that image_size: The size of the image in the dataset. This assumes that `width`
`width` == `height`. Set to 'infer' to infer the image size from TFDS == `height`. Set to 'infer' to infer the image size from TFDS info. This
info. This requires `name` to be a registered dataset in TFDS. requires `name` to be a registered dataset in TFDS.
num_classes: The number of classes given by the dataset. Set to 'infer' num_classes: The number of classes given by the dataset. Set to 'infer' to
to infer the image size from TFDS info. This requires `name` to be a infer the image size from TFDS info. This requires `name` to be a
registered dataset in TFDS. registered dataset in TFDS.
num_channels: The number of channels given by the dataset. Set to 'infer' num_channels: The number of channels given by the dataset. Set to 'infer' to
to infer the image size from TFDS info. This requires `name` to be a infer the image size from TFDS info. This requires `name` to be a
registered dataset in TFDS. registered dataset in TFDS.
num_examples: The number of examples given by the dataset. Set to 'infer' num_examples: The number of examples given by the dataset. Set to 'infer' to
to infer the image size from TFDS info. This requires `name` to be a infer the image size from TFDS info. This requires `name` to be a
registered dataset in TFDS. registered dataset in TFDS.
batch_size: The base batch size for the dataset. batch_size: The base batch size for the dataset.
use_per_replica_batch_size: Whether to scale the batch size based on use_per_replica_batch_size: Whether to scale the batch size based on
...@@ -284,10 +284,10 @@ class DatasetBuilder: ...@@ -284,10 +284,10 @@ class DatasetBuilder:
""" """
if strategy: if strategy:
if strategy.num_replicas_in_sync != self.config.num_devices: if strategy.num_replicas_in_sync != self.config.num_devices:
logging.warn('Passed a strategy with %d devices, but expected' logging.warn(
'%d devices.', 'Passed a strategy with %d devices, but expected'
strategy.num_replicas_in_sync, '%d devices.', strategy.num_replicas_in_sync,
self.config.num_devices) self.config.num_devices)
dataset = strategy.experimental_distribute_datasets_from_function( dataset = strategy.experimental_distribute_datasets_from_function(
self._build) self._build)
else: else:
...@@ -295,8 +295,9 @@ class DatasetBuilder: ...@@ -295,8 +295,9 @@ class DatasetBuilder:
return dataset return dataset
def _build(self, input_context: tf.distribute.InputContext = None def _build(
) -> tf.data.Dataset: self,
input_context: tf.distribute.InputContext = None) -> tf.data.Dataset:
"""Construct a dataset end-to-end and return it. """Construct a dataset end-to-end and return it.
Args: Args:
...@@ -328,8 +329,7 @@ class DatasetBuilder: ...@@ -328,8 +329,7 @@ class DatasetBuilder:
logging.info('Using TFDS to load data.') logging.info('Using TFDS to load data.')
builder = tfds.builder(self.config.name, builder = tfds.builder(self.config.name, data_dir=self.config.data_dir)
data_dir=self.config.data_dir)
if self.config.download: if self.config.download:
builder.download_and_prepare() builder.download_and_prepare()
...@@ -380,8 +380,8 @@ class DatasetBuilder: ...@@ -380,8 +380,8 @@ class DatasetBuilder:
dataset = tf.data.Dataset.range(1) dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat() dataset = dataset.repeat()
dataset = dataset.map(generate_data, dataset = dataset.map(
num_parallel_calls=tf.data.experimental.AUTOTUNE) generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset return dataset
def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset: def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
...@@ -393,14 +393,14 @@ class DatasetBuilder: ...@@ -393,14 +393,14 @@ class DatasetBuilder:
Returns: Returns:
A TensorFlow dataset outputting batched images and labels. A TensorFlow dataset outputting batched images and labels.
""" """
if (self.config.builder != 'tfds' and self.input_context if (self.config.builder != 'tfds' and self.input_context and
and self.input_context.num_input_pipelines > 1): self.input_context.num_input_pipelines > 1):
dataset = dataset.shard(self.input_context.num_input_pipelines, dataset = dataset.shard(self.input_context.num_input_pipelines,
self.input_context.input_pipeline_id) self.input_context.input_pipeline_id)
logging.info('Sharding the dataset: input_pipeline_id=%d ' logging.info(
'num_input_pipelines=%d', 'Sharding the dataset: input_pipeline_id=%d '
self.input_context.num_input_pipelines, 'num_input_pipelines=%d', self.input_context.num_input_pipelines,
self.input_context.input_pipeline_id) self.input_context.input_pipeline_id)
if self.is_training and self.config.builder == 'records': if self.is_training and self.config.builder == 'records':
# Shuffle the input files. # Shuffle the input files.
...@@ -429,8 +429,8 @@ class DatasetBuilder: ...@@ -429,8 +429,8 @@ class DatasetBuilder:
preprocess = self.parse_record preprocess = self.parse_record
else: else:
preprocess = self.preprocess preprocess = self.preprocess
dataset = dataset.map(preprocess, dataset = dataset.map(
num_parallel_calls=tf.data.experimental.AUTOTUNE) preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
if self.input_context and self.config.num_devices > 1: if self.input_context and self.config.num_devices > 1:
if not self.config.use_per_replica_batch_size: if not self.config.use_per_replica_batch_size:
...@@ -444,11 +444,11 @@ class DatasetBuilder: ...@@ -444,11 +444,11 @@ class DatasetBuilder:
# The batch size of the dataset will be multiplied by the number of # The batch size of the dataset will be multiplied by the number of
# replicas automatically when strategy.distribute_datasets_from_function # replicas automatically when strategy.distribute_datasets_from_function
# is called, so we use local batch size here. # is called, so we use local batch size here.
dataset = dataset.batch(self.local_batch_size, dataset = dataset.batch(
drop_remainder=self.is_training) self.local_batch_size, drop_remainder=self.is_training)
else: else:
dataset = dataset.batch(self.global_batch_size, dataset = dataset.batch(
drop_remainder=self.is_training) self.global_batch_size, drop_remainder=self.is_training)
# Prefetch overlaps in-feed with training # Prefetch overlaps in-feed with training
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE) dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
...@@ -470,24 +470,15 @@ class DatasetBuilder: ...@@ -470,24 +470,15 @@ class DatasetBuilder:
def parse_record(self, record: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]: def parse_record(self, record: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""Parse an ImageNet record from a serialized string Tensor.""" """Parse an ImageNet record from a serialized string Tensor."""
keys_to_features = { keys_to_features = {
'image/encoded': 'image/encoded': tf.io.FixedLenFeature((), tf.string, ''),
tf.io.FixedLenFeature((), tf.string, ''), 'image/format': tf.io.FixedLenFeature((), tf.string, 'jpeg'),
'image/format': 'image/class/label': tf.io.FixedLenFeature([], tf.int64, -1),
tf.io.FixedLenFeature((), tf.string, 'jpeg'), 'image/class/text': tf.io.FixedLenFeature([], tf.string, ''),
'image/class/label': 'image/object/bbox/xmin': tf.io.VarLenFeature(dtype=tf.float32),
tf.io.FixedLenFeature([], tf.int64, -1), 'image/object/bbox/ymin': tf.io.VarLenFeature(dtype=tf.float32),
'image/class/text': 'image/object/bbox/xmax': tf.io.VarLenFeature(dtype=tf.float32),
tf.io.FixedLenFeature([], tf.string, ''), 'image/object/bbox/ymax': tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/xmin': 'image/object/class/label': tf.io.VarLenFeature(dtype=tf.int64),
tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/ymin':
tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/xmax':
tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/ymax':
tf.io.VarLenFeature(dtype=tf.float32),
'image/object/class/label':
tf.io.VarLenFeature(dtype=tf.int64),
} }
parsed = tf.io.parse_single_example(record, keys_to_features) parsed = tf.io.parse_single_example(record, keys_to_features)
...@@ -502,8 +493,8 @@ class DatasetBuilder: ...@@ -502,8 +493,8 @@ class DatasetBuilder:
return image, label return image, label
def preprocess(self, image: tf.Tensor, label: tf.Tensor def preprocess(self, image: tf.Tensor,
) -> Tuple[tf.Tensor, tf.Tensor]: label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""Apply image preprocessing and augmentation to the image and label.""" """Apply image preprocessing and augmentation to the image and label."""
if self.is_training: if self.is_training:
image = preprocessing.preprocess_for_train( image = preprocessing.preprocess_for_train(
......
...@@ -79,7 +79,7 @@ def get_batch_norm(batch_norm_type: Text) -> tf.keras.layers.BatchNormalization: ...@@ -79,7 +79,7 @@ def get_batch_norm(batch_norm_type: Text) -> tf.keras.layers.BatchNormalization:
Args: Args:
batch_norm_type: The type of batch normalization layer implementation. `tpu` batch_norm_type: The type of batch normalization layer implementation. `tpu`
will use `TpuBatchNormalization`. will use `TpuBatchNormalization`.
Returns: Returns:
An instance of `tf.keras.layers.BatchNormalization`. An instance of `tf.keras.layers.BatchNormalization`.
...@@ -95,8 +95,10 @@ def count_params(model, trainable_only=True): ...@@ -95,8 +95,10 @@ def count_params(model, trainable_only=True):
if not trainable_only: if not trainable_only:
return model.count_params() return model.count_params()
else: else:
return int(np.sum([tf.keras.backend.count_params(p) return int(
for p in model.trainable_weights])) np.sum([
tf.keras.backend.count_params(p) for p in model.trainable_weights
]))
def load_weights(model: tf.keras.Model, def load_weights(model: tf.keras.Model,
...@@ -107,8 +109,8 @@ def load_weights(model: tf.keras.Model, ...@@ -107,8 +109,8 @@ def load_weights(model: tf.keras.Model,
Args: Args:
model: the model to load weights into model: the model to load weights into
model_weights_path: the path of the model weights model_weights_path: the path of the model weights
weights_format: the model weights format. One of 'saved_model', 'h5', weights_format: the model weights format. One of 'saved_model', 'h5', or
or 'checkpoint'. 'checkpoint'.
""" """
if weights_format == 'saved_model': if weights_format == 'saved_model':
loaded_model = tf.keras.models.load_model(model_weights_path) loaded_model = tf.keras.models.load_model(model_weights_path)
......
...@@ -64,11 +64,11 @@ class ModelConfig(base_config.Config): ...@@ -64,11 +64,11 @@ class ModelConfig(base_config.Config):
# (input_filters, output_filters, kernel_size, num_repeat, # (input_filters, output_filters, kernel_size, num_repeat,
# expand_ratio, strides, se_ratio) # expand_ratio, strides, se_ratio)
# pylint: disable=bad-whitespace # pylint: disable=bad-whitespace
BlockConfig.from_args(32, 16, 3, 1, 1, (1, 1), 0.25), BlockConfig.from_args(32, 16, 3, 1, 1, (1, 1), 0.25),
BlockConfig.from_args(16, 24, 3, 2, 6, (2, 2), 0.25), BlockConfig.from_args(16, 24, 3, 2, 6, (2, 2), 0.25),
BlockConfig.from_args(24, 40, 5, 2, 6, (2, 2), 0.25), BlockConfig.from_args(24, 40, 5, 2, 6, (2, 2), 0.25),
BlockConfig.from_args(40, 80, 3, 3, 6, (2, 2), 0.25), BlockConfig.from_args(40, 80, 3, 3, 6, (2, 2), 0.25),
BlockConfig.from_args(80, 112, 5, 3, 6, (1, 1), 0.25), BlockConfig.from_args(80, 112, 5, 3, 6, (1, 1), 0.25),
BlockConfig.from_args(112, 192, 5, 4, 6, (2, 2), 0.25), BlockConfig.from_args(112, 192, 5, 4, 6, (2, 2), 0.25),
BlockConfig.from_args(192, 320, 3, 1, 6, (1, 1), 0.25), BlockConfig.from_args(192, 320, 3, 1, 6, (1, 1), 0.25),
# pylint: enable=bad-whitespace # pylint: enable=bad-whitespace
...@@ -128,8 +128,7 @@ DENSE_KERNEL_INITIALIZER = { ...@@ -128,8 +128,7 @@ DENSE_KERNEL_INITIALIZER = {
} }
def round_filters(filters: int, def round_filters(filters: int, config: ModelConfig) -> int:
config: ModelConfig) -> int:
"""Round number of filters based on width coefficient.""" """Round number of filters based on width coefficient."""
width_coefficient = config.width_coefficient width_coefficient = config.width_coefficient
min_depth = config.min_depth min_depth = config.min_depth
...@@ -189,21 +188,24 @@ def conv2d_block(inputs: tf.Tensor, ...@@ -189,21 +188,24 @@ def conv2d_block(inputs: tf.Tensor,
init_kwargs.update({'depthwise_initializer': CONV_KERNEL_INITIALIZER}) init_kwargs.update({'depthwise_initializer': CONV_KERNEL_INITIALIZER})
else: else:
conv2d = tf.keras.layers.Conv2D conv2d = tf.keras.layers.Conv2D
init_kwargs.update({'filters': conv_filters, init_kwargs.update({
'kernel_initializer': CONV_KERNEL_INITIALIZER}) 'filters': conv_filters,
'kernel_initializer': CONV_KERNEL_INITIALIZER
})
x = conv2d(**init_kwargs)(inputs) x = conv2d(**init_kwargs)(inputs)
if use_batch_norm: if use_batch_norm:
bn_axis = 1 if data_format == 'channels_first' else -1 bn_axis = 1 if data_format == 'channels_first' else -1
x = batch_norm(axis=bn_axis, x = batch_norm(
momentum=bn_momentum, axis=bn_axis,
epsilon=bn_epsilon, momentum=bn_momentum,
name=name + '_bn')(x) epsilon=bn_epsilon,
name=name + '_bn')(
x)
if activation is not None: if activation is not None:
x = tf.keras.layers.Activation(activation, x = tf.keras.layers.Activation(activation, name=name + '_activation')(x)
name=name + '_activation')(x)
return x return x
...@@ -235,42 +237,43 @@ def mb_conv_block(inputs: tf.Tensor, ...@@ -235,42 +237,43 @@ def mb_conv_block(inputs: tf.Tensor,
if block.fused_conv: if block.fused_conv:
# If we use fused mbconv, skip expansion and use regular conv. # If we use fused mbconv, skip expansion and use regular conv.
x = conv2d_block(x, x = conv2d_block(
filters, x,
config, filters,
kernel_size=block.kernel_size, config,
strides=block.strides, kernel_size=block.kernel_size,
activation=activation, strides=block.strides,
name=prefix + 'fused') activation=activation,
name=prefix + 'fused')
else: else:
if block.expand_ratio != 1: if block.expand_ratio != 1:
# Expansion phase # Expansion phase
kernel_size = (1, 1) if use_depthwise else (3, 3) kernel_size = (1, 1) if use_depthwise else (3, 3)
x = conv2d_block(x, x = conv2d_block(
filters, x,
config, filters,
kernel_size=kernel_size, config,
activation=activation, kernel_size=kernel_size,
name=prefix + 'expand') activation=activation,
name=prefix + 'expand')
# Depthwise Convolution # Depthwise Convolution
if use_depthwise: if use_depthwise:
x = conv2d_block(x, x = conv2d_block(
conv_filters=None, x,
config=config, conv_filters=None,
kernel_size=block.kernel_size, config=config,
strides=block.strides, kernel_size=block.kernel_size,
activation=activation, strides=block.strides,
depthwise=True, activation=activation,
name=prefix + 'depthwise') depthwise=True,
name=prefix + 'depthwise')
# Squeeze and Excitation phase # Squeeze and Excitation phase
if use_se: if use_se:
assert block.se_ratio is not None assert block.se_ratio is not None
assert 0 < block.se_ratio <= 1 assert 0 < block.se_ratio <= 1
num_reduced_filters = max(1, int( num_reduced_filters = max(1, int(block.input_filters * block.se_ratio))
block.input_filters * block.se_ratio
))
if data_format == 'channels_first': if data_format == 'channels_first':
se_shape = (filters, 1, 1) se_shape = (filters, 1, 1)
...@@ -280,53 +283,51 @@ def mb_conv_block(inputs: tf.Tensor, ...@@ -280,53 +283,51 @@ def mb_conv_block(inputs: tf.Tensor,
se = tf.keras.layers.GlobalAveragePooling2D(name=prefix + 'se_squeeze')(x) se = tf.keras.layers.GlobalAveragePooling2D(name=prefix + 'se_squeeze')(x)
se = tf.keras.layers.Reshape(se_shape, name=prefix + 'se_reshape')(se) se = tf.keras.layers.Reshape(se_shape, name=prefix + 'se_reshape')(se)
se = conv2d_block(se, se = conv2d_block(
num_reduced_filters, se,
config, num_reduced_filters,
use_bias=True, config,
use_batch_norm=False, use_bias=True,
activation=activation, use_batch_norm=False,
name=prefix + 'se_reduce') activation=activation,
se = conv2d_block(se, name=prefix + 'se_reduce')
filters, se = conv2d_block(
config, se,
use_bias=True, filters,
use_batch_norm=False, config,
activation='sigmoid', use_bias=True,
name=prefix + 'se_expand') use_batch_norm=False,
activation='sigmoid',
name=prefix + 'se_expand')
x = tf.keras.layers.multiply([x, se], name=prefix + 'se_excite') x = tf.keras.layers.multiply([x, se], name=prefix + 'se_excite')
# Output phase # Output phase
x = conv2d_block(x, x = conv2d_block(
block.output_filters, x, block.output_filters, config, activation=None, name=prefix + 'project')
config,
activation=None,
name=prefix + 'project')
# Add identity so that quantization-aware training can insert quantization # Add identity so that quantization-aware training can insert quantization
# ops correctly. # ops correctly.
x = tf.keras.layers.Activation(tf_utils.get_activation('identity'), x = tf.keras.layers.Activation(
name=prefix + 'id')(x) tf_utils.get_activation('identity'), name=prefix + 'id')(
x)
if (block.id_skip if (block.id_skip and all(s == 1 for s in block.strides) and
and all(s == 1 for s in block.strides) block.input_filters == block.output_filters):
and block.input_filters == block.output_filters):
if drop_connect_rate and drop_connect_rate > 0: if drop_connect_rate and drop_connect_rate > 0:
# Apply dropconnect # Apply dropconnect
# The only difference between dropout and dropconnect in TF is scaling by # The only difference between dropout and dropconnect in TF is scaling by
# drop_connect_rate during training. See: # drop_connect_rate during training. See:
# https://github.com/keras-team/keras/pull/9898#issuecomment-380577612 # https://github.com/keras-team/keras/pull/9898#issuecomment-380577612
x = tf.keras.layers.Dropout(drop_connect_rate, x = tf.keras.layers.Dropout(
noise_shape=(None, 1, 1, 1), drop_connect_rate, noise_shape=(None, 1, 1, 1), name=prefix + 'drop')(
name=prefix + 'drop')(x) x)
x = tf.keras.layers.add([x, inputs], name=prefix + 'add') x = tf.keras.layers.add([x, inputs], name=prefix + 'add')
return x return x
def efficientnet(image_input: tf.keras.layers.Input, def efficientnet(image_input: tf.keras.layers.Input, config: ModelConfig):
config: ModelConfig):
"""Creates an EfficientNet graph given the model parameters. """Creates an EfficientNet graph given the model parameters.
This function is wrapped by the `EfficientNet` class to make a tf.keras.Model. This function is wrapped by the `EfficientNet` class to make a tf.keras.Model.
...@@ -357,19 +358,18 @@ def efficientnet(image_input: tf.keras.layers.Input, ...@@ -357,19 +358,18 @@ def efficientnet(image_input: tf.keras.layers.Input,
# Happens on GPU/TPU if available. # Happens on GPU/TPU if available.
x = tf.keras.layers.Permute((3, 1, 2))(x) x = tf.keras.layers.Permute((3, 1, 2))(x)
if rescale_input: if rescale_input:
x = preprocessing.normalize_images(x, x = preprocessing.normalize_images(
num_channels=input_channels, x, num_channels=input_channels, dtype=dtype, data_format=data_format)
dtype=dtype,
data_format=data_format)
# Build stem # Build stem
x = conv2d_block(x, x = conv2d_block(
round_filters(stem_base_filters, config), x,
config, round_filters(stem_base_filters, config),
kernel_size=[3, 3], config,
strides=[2, 2], kernel_size=[3, 3],
activation=activation, strides=[2, 2],
name='stem') activation=activation,
name='stem')
# Build blocks # Build blocks
num_blocks_total = sum( num_blocks_total = sum(
...@@ -391,10 +391,7 @@ def efficientnet(image_input: tf.keras.layers.Input, ...@@ -391,10 +391,7 @@ def efficientnet(image_input: tf.keras.layers.Input,
x = mb_conv_block(x, block, config, block_prefix) x = mb_conv_block(x, block, config, block_prefix)
block_num += 1 block_num += 1
if block.num_repeat > 1: if block.num_repeat > 1:
block = block.replace( block = block.replace(input_filters=block.output_filters, strides=[1, 1])
input_filters=block.output_filters,
strides=[1, 1]
)
for block_idx in range(block.num_repeat - 1): for block_idx in range(block.num_repeat - 1):
drop_rate = drop_connect_rate * float(block_num) / num_blocks_total drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
...@@ -404,11 +401,12 @@ def efficientnet(image_input: tf.keras.layers.Input, ...@@ -404,11 +401,12 @@ def efficientnet(image_input: tf.keras.layers.Input,
block_num += 1 block_num += 1
# Build top # Build top
x = conv2d_block(x, x = conv2d_block(
round_filters(top_base_filters, config), x,
config, round_filters(top_base_filters, config),
activation=activation, config,
name='top') activation=activation,
name='top')
# Build classifier # Build classifier
x = tf.keras.layers.GlobalAveragePooling2D(name='top_pool')(x) x = tf.keras.layers.GlobalAveragePooling2D(name='top_pool')(x)
...@@ -419,7 +417,8 @@ def efficientnet(image_input: tf.keras.layers.Input, ...@@ -419,7 +417,8 @@ def efficientnet(image_input: tf.keras.layers.Input,
kernel_initializer=DENSE_KERNEL_INITIALIZER, kernel_initializer=DENSE_KERNEL_INITIALIZER,
kernel_regularizer=tf.keras.regularizers.l2(weight_decay), kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
bias_regularizer=tf.keras.regularizers.l2(weight_decay), bias_regularizer=tf.keras.regularizers.l2(weight_decay),
name='logits')(x) name='logits')(
x)
x = tf.keras.layers.Activation('softmax', name='probs')(x) x = tf.keras.layers.Activation('softmax', name='probs')(x)
return x return x
...@@ -439,8 +438,7 @@ class EfficientNet(tf.keras.Model): ...@@ -439,8 +438,7 @@ class EfficientNet(tf.keras.Model):
Args: Args:
config: (optional) the main model parameters to create the model config: (optional) the main model parameters to create the model
overrides: (optional) a dict containing keys that can override overrides: (optional) a dict containing keys that can override config
config
""" """
overrides = overrides or {} overrides = overrides or {}
config = config or ModelConfig() config = config or ModelConfig()
...@@ -457,9 +455,7 @@ class EfficientNet(tf.keras.Model): ...@@ -457,9 +455,7 @@ class EfficientNet(tf.keras.Model):
# Cast to float32 in case we have a different model dtype # Cast to float32 in case we have a different model dtype
output = tf.cast(output, tf.float32) output = tf.cast(output, tf.float32)
logging.info('Building model %s with params %s', logging.info('Building model %s with params %s', model_name, self.config)
model_name,
self.config)
super(EfficientNet, self).__init__( super(EfficientNet, self).__init__(
inputs=image_input, outputs=output, name=model_name) inputs=image_input, outputs=output, name=model_name)
...@@ -477,8 +473,8 @@ class EfficientNet(tf.keras.Model): ...@@ -477,8 +473,8 @@ class EfficientNet(tf.keras.Model):
Args: Args:
model_name: the predefined model name model_name: the predefined model name
model_weights_path: the path to the weights (h5 file or saved model dir) model_weights_path: the path to the weights (h5 file or saved model dir)
weights_format: the model weights format. One of 'saved_model', 'h5', weights_format: the model weights format. One of 'saved_model', 'h5', or
or 'checkpoint'. 'checkpoint'.
overrides: (optional) a dict containing keys that can override config overrides: (optional) a dict containing keys that can override config
Returns: Returns:
...@@ -498,8 +494,7 @@ class EfficientNet(tf.keras.Model): ...@@ -498,8 +494,7 @@ class EfficientNet(tf.keras.Model):
model = cls(config=config, overrides=overrides) model = cls(config=config, overrides=overrides)
if model_weights_path: if model_weights_path:
common_modules.load_weights(model, common_modules.load_weights(
model_weights_path, model, model_weights_path, weights_format=weights_format)
weights_format=weights_format)
return model return model
...@@ -30,10 +30,8 @@ from official.vision.image_classification.efficientnet import efficientnet_model ...@@ -30,10 +30,8 @@ from official.vision.image_classification.efficientnet import efficientnet_model
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
flags.DEFINE_string("model_name", None, flags.DEFINE_string("model_name", None, "EfficientNet model name.")
"EfficientNet model name.") flags.DEFINE_string("model_path", None, "File path to TF model checkpoint.")
flags.DEFINE_string("model_path", None,
"File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None, flags.DEFINE_string("export_path", None,
"TF-Hub SavedModel destination path to export.") "TF-Hub SavedModel destination path to export.")
...@@ -65,5 +63,6 @@ def main(argv): ...@@ -65,5 +63,6 @@ def main(argv):
export_tfhub(FLAGS.model_path, FLAGS.export_path, FLAGS.model_name) export_tfhub(FLAGS.model_path, FLAGS.export_path, FLAGS.model_name)
if __name__ == "__main__": if __name__ == "__main__":
app.run(main) app.run(main)
...@@ -29,11 +29,10 @@ BASE_LEARNING_RATE = 0.1 ...@@ -29,11 +29,10 @@ BASE_LEARNING_RATE = 0.1
class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule): class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
"""A wrapper for LearningRateSchedule that includes warmup steps.""" """A wrapper for LearningRateSchedule that includes warmup steps."""
def __init__( def __init__(self,
self, lr_schedule: tf.keras.optimizers.schedules.LearningRateSchedule,
lr_schedule: tf.keras.optimizers.schedules.LearningRateSchedule, warmup_steps: int,
warmup_steps: int, warmup_lr: Optional[float] = None):
warmup_lr: Optional[float] = None):
"""Add warmup decay to a learning rate schedule. """Add warmup decay to a learning rate schedule.
Args: Args:
...@@ -42,7 +41,6 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule): ...@@ -42,7 +41,6 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
warmup_lr: an optional field for the final warmup learning rate. This warmup_lr: an optional field for the final warmup learning rate. This
should be provided if the base `lr_schedule` does not contain this should be provided if the base `lr_schedule` does not contain this
field. field.
""" """
super(WarmupDecaySchedule, self).__init__() super(WarmupDecaySchedule, self).__init__()
self._lr_schedule = lr_schedule self._lr_schedule = lr_schedule
...@@ -63,8 +61,7 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule): ...@@ -63,8 +61,7 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
global_step_recomp = tf.cast(step, dtype) global_step_recomp = tf.cast(step, dtype)
warmup_steps = tf.cast(self._warmup_steps, dtype) warmup_steps = tf.cast(self._warmup_steps, dtype)
warmup_lr = initial_learning_rate * global_step_recomp / warmup_steps warmup_lr = initial_learning_rate * global_step_recomp / warmup_steps
lr = tf.cond(global_step_recomp < warmup_steps, lr = tf.cond(global_step_recomp < warmup_steps, lambda: warmup_lr,
lambda: warmup_lr,
lambda: lr) lambda: lr)
return lr return lr
......
...@@ -37,14 +37,13 @@ class LearningRateTests(tf.test.TestCase): ...@@ -37,14 +37,13 @@ class LearningRateTests(tf.test.TestCase):
decay_steps=decay_steps, decay_steps=decay_steps,
decay_rate=decay_rate) decay_rate=decay_rate)
lr = learning_rate.WarmupDecaySchedule( lr = learning_rate.WarmupDecaySchedule(
lr_schedule=base_lr, lr_schedule=base_lr, warmup_steps=warmup_steps)
warmup_steps=warmup_steps)
for step in range(warmup_steps - 1): for step in range(warmup_steps - 1):
config = lr.get_config() config = lr.get_config()
self.assertEqual(config['warmup_steps'], warmup_steps) self.assertEqual(config['warmup_steps'], warmup_steps)
self.assertAllClose(self.evaluate(lr(step)), self.assertAllClose(
step / warmup_steps * initial_lr) self.evaluate(lr(step)), step / warmup_steps * initial_lr)
def test_cosine_decay_with_warmup(self): def test_cosine_decay_with_warmup(self):
"""Basic computational test for cosine decay with warmup.""" """Basic computational test for cosine decay with warmup."""
......
...@@ -19,6 +19,7 @@ from __future__ import print_function ...@@ -19,6 +19,7 @@ from __future__ import print_function
import os import os
# Import libraries
from absl import app from absl import app
from absl import flags from absl import flags
from absl import logging from absl import logging
......
...@@ -58,7 +58,8 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase): ...@@ -58,7 +58,8 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
"""Test Keras MNIST model with `strategy`.""" """Test Keras MNIST model with `strategy`."""
extra_flags = [ extra_flags = [
"-train_epochs", "1", "-train_epochs",
"1",
# Let TFDS find the metadata folder automatically # Let TFDS find the metadata folder automatically
"--data_dir=" "--data_dir="
] ]
...@@ -72,9 +73,10 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase): ...@@ -72,9 +73,10 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
tf.data.Dataset.from_tensor_slices(dummy_data), tf.data.Dataset.from_tensor_slices(dummy_data),
) )
run = functools.partial(mnist_main.run, run = functools.partial(
datasets_override=datasets, mnist_main.run,
strategy_override=distribution) datasets_override=datasets,
strategy_override=distribution)
integration.run_synthetic( integration.run_synthetic(
main=run, main=run,
......
...@@ -65,19 +65,19 @@ class MovingAverage(tf.keras.optimizers.Optimizer): ...@@ -65,19 +65,19 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
"""Construct a new MovingAverage optimizer. """Construct a new MovingAverage optimizer.
Args: Args:
optimizer: `tf.keras.optimizers.Optimizer` that will be optimizer: `tf.keras.optimizers.Optimizer` that will be used to compute
used to compute and apply gradients. and apply gradients.
average_decay: float. Decay to use to maintain the moving averages average_decay: float. Decay to use to maintain the moving averages of
of trained variables. trained variables.
start_step: int. What step to start the moving average. start_step: int. What step to start the moving average.
dynamic_decay: bool. Whether to change the decay based on the number dynamic_decay: bool. Whether to change the decay based on the number of
of optimizer updates. Decay will start at 0.1 and gradually increase optimizer updates. Decay will start at 0.1 and gradually increase up to
up to `average_decay` after each optimizer update. This behavior is `average_decay` after each optimizer update. This behavior is similar to
similar to `tf.train.ExponentialMovingAverage` in TF 1.x. `tf.train.ExponentialMovingAverage` in TF 1.x.
name: Optional name for the operations created when applying name: Optional name for the operations created when applying gradients.
gradients. Defaults to "moving_average". Defaults to "moving_average".
**kwargs: keyword arguments. Allowed to be {`clipnorm`, **kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
`clipvalue`, `lr`, `decay`}. `decay`}.
""" """
super(MovingAverage, self).__init__(name, **kwargs) super(MovingAverage, self).__init__(name, **kwargs)
self._optimizer = optimizer self._optimizer = optimizer
...@@ -128,8 +128,8 @@ class MovingAverage(tf.keras.optimizers.Optimizer): ...@@ -128,8 +128,8 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
strategy.extended.update(v_moving, _apply_moving, args=(v_normal,)) strategy.extended.update(v_moving, _apply_moving, args=(v_normal,))
ctx = tf.distribute.get_replica_context() ctx = tf.distribute.get_replica_context()
return ctx.merge_call(_update, args=(zip(self._average_weights, return ctx.merge_call(
self._model_weights),)) _update, args=(zip(self._average_weights, self._model_weights),))
def swap_weights(self): def swap_weights(self):
"""Swap the average and moving weights. """Swap the average and moving weights.
...@@ -148,12 +148,15 @@ class MovingAverage(tf.keras.optimizers.Optimizer): ...@@ -148,12 +148,15 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
@tf.function @tf.function
def _swap_weights(self): def _swap_weights(self):
def fn_0(a, b): def fn_0(a, b):
a.assign_add(b) a.assign_add(b)
return a return a
def fn_1(b, a): def fn_1(b, a):
b.assign(a - b) b.assign(a - b)
return b return b
def fn_2(a, b): def fn_2(a, b):
a.assign_sub(b) a.assign_sub(b)
return a return a
...@@ -174,12 +177,14 @@ class MovingAverage(tf.keras.optimizers.Optimizer): ...@@ -174,12 +177,14 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
Args: Args:
var_list: List of model variables to be assigned to their average. var_list: List of model variables to be assigned to their average.
Returns: Returns:
assign_op: The op corresponding to the assignment operation of assign_op: The op corresponding to the assignment operation of
variables to their average. variables to their average.
""" """
assign_op = tf.group([ assign_op = tf.group([
var.assign(self.get_slot(var, 'average')) for var in var_list var.assign(self.get_slot(var, 'average'))
for var in var_list
if var.trainable if var.trainable
]) ])
return assign_op return assign_op
...@@ -256,13 +261,13 @@ def build_optimizer( ...@@ -256,13 +261,13 @@ def build_optimizer(
"""Build the optimizer based on name. """Build the optimizer based on name.
Args: Args:
optimizer_name: String representation of the optimizer name. Examples: optimizer_name: String representation of the optimizer name. Examples: sgd,
sgd, momentum, rmsprop. momentum, rmsprop.
base_learning_rate: `tf.keras.optimizers.schedules.LearningRateSchedule` base_learning_rate: `tf.keras.optimizers.schedules.LearningRateSchedule`
base learning rate. base learning rate.
params: String -> Any dictionary representing the optimizer params. params: String -> Any dictionary representing the optimizer params. This
This should contain optimizer specific parameters such as should contain optimizer specific parameters such as `base_learning_rate`,
`base_learning_rate`, `decay`, etc. `decay`, etc.
model: The `tf.keras.Model`. This is used for the shadow copy if using model: The `tf.keras.Model`. This is used for the shadow copy if using
`MovingAverage`. `MovingAverage`.
...@@ -279,43 +284,47 @@ def build_optimizer( ...@@ -279,43 +284,47 @@ def build_optimizer(
if optimizer_name == 'sgd': if optimizer_name == 'sgd':
logging.info('Using SGD optimizer') logging.info('Using SGD optimizer')
nesterov = params.get('nesterov', False) nesterov = params.get('nesterov', False)
optimizer = tf.keras.optimizers.SGD(learning_rate=base_learning_rate, optimizer = tf.keras.optimizers.SGD(
nesterov=nesterov) learning_rate=base_learning_rate, nesterov=nesterov)
elif optimizer_name == 'momentum': elif optimizer_name == 'momentum':
logging.info('Using momentum optimizer') logging.info('Using momentum optimizer')
nesterov = params.get('nesterov', False) nesterov = params.get('nesterov', False)
optimizer = tf.keras.optimizers.SGD(learning_rate=base_learning_rate, optimizer = tf.keras.optimizers.SGD(
momentum=params['momentum'], learning_rate=base_learning_rate,
nesterov=nesterov) momentum=params['momentum'],
nesterov=nesterov)
elif optimizer_name == 'rmsprop': elif optimizer_name == 'rmsprop':
logging.info('Using RMSProp') logging.info('Using RMSProp')
rho = params.get('decay', None) or params.get('rho', 0.9) rho = params.get('decay', None) or params.get('rho', 0.9)
momentum = params.get('momentum', 0.9) momentum = params.get('momentum', 0.9)
epsilon = params.get('epsilon', 1e-07) epsilon = params.get('epsilon', 1e-07)
optimizer = tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate, optimizer = tf.keras.optimizers.RMSprop(
rho=rho, learning_rate=base_learning_rate,
momentum=momentum, rho=rho,
epsilon=epsilon) momentum=momentum,
epsilon=epsilon)
elif optimizer_name == 'adam': elif optimizer_name == 'adam':
logging.info('Using Adam') logging.info('Using Adam')
beta_1 = params.get('beta_1', 0.9) beta_1 = params.get('beta_1', 0.9)
beta_2 = params.get('beta_2', 0.999) beta_2 = params.get('beta_2', 0.999)
epsilon = params.get('epsilon', 1e-07) epsilon = params.get('epsilon', 1e-07)
optimizer = tf.keras.optimizers.Adam(learning_rate=base_learning_rate, optimizer = tf.keras.optimizers.Adam(
beta_1=beta_1, learning_rate=base_learning_rate,
beta_2=beta_2, beta_1=beta_1,
epsilon=epsilon) beta_2=beta_2,
epsilon=epsilon)
elif optimizer_name == 'adamw': elif optimizer_name == 'adamw':
logging.info('Using AdamW') logging.info('Using AdamW')
weight_decay = params.get('weight_decay', 0.01) weight_decay = params.get('weight_decay', 0.01)
beta_1 = params.get('beta_1', 0.9) beta_1 = params.get('beta_1', 0.9)
beta_2 = params.get('beta_2', 0.999) beta_2 = params.get('beta_2', 0.999)
epsilon = params.get('epsilon', 1e-07) epsilon = params.get('epsilon', 1e-07)
optimizer = tfa.optimizers.AdamW(weight_decay=weight_decay, optimizer = tfa.optimizers.AdamW(
learning_rate=base_learning_rate, weight_decay=weight_decay,
beta_1=beta_1, learning_rate=base_learning_rate,
beta_2=beta_2, beta_1=beta_1,
epsilon=epsilon) beta_2=beta_2,
epsilon=epsilon)
else: else:
raise ValueError('Unknown optimizer %s' % optimizer_name) raise ValueError('Unknown optimizer %s' % optimizer_name)
...@@ -330,8 +339,7 @@ def build_optimizer( ...@@ -330,8 +339,7 @@ def build_optimizer(
raise ValueError('`model` must be provided if using `MovingAverage`.') raise ValueError('`model` must be provided if using `MovingAverage`.')
logging.info('Including moving average decay.') logging.info('Including moving average decay.')
optimizer = MovingAverage( optimizer = MovingAverage(
optimizer=optimizer, optimizer=optimizer, average_decay=moving_average_decay)
average_decay=moving_average_decay)
optimizer.shadow_copy(model) optimizer.shadow_copy(model)
return optimizer return optimizer
...@@ -358,13 +366,15 @@ def build_learning_rate(params: base_configs.LearningRateConfig, ...@@ -358,13 +366,15 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
if lr_multiplier and lr_multiplier > 0: if lr_multiplier and lr_multiplier > 0:
# Scale the learning rate based on the batch size and a multiplier # Scale the learning rate based on the batch size and a multiplier
base_lr *= lr_multiplier * batch_size base_lr *= lr_multiplier * batch_size
logging.info('Scaling the learning rate based on the batch size ' logging.info(
'multiplier. New base_lr: %f', base_lr) 'Scaling the learning rate based on the batch size '
'multiplier. New base_lr: %f', base_lr)
if decay_type == 'exponential': if decay_type == 'exponential':
logging.info('Using exponential learning rate with: ' logging.info(
'initial_learning_rate: %f, decay_steps: %d, ' 'Using exponential learning rate with: '
'decay_rate: %f', base_lr, decay_steps, decay_rate) 'initial_learning_rate: %f, decay_steps: %d, '
'decay_rate: %f', base_lr, decay_steps, decay_rate)
lr = tf.keras.optimizers.schedules.ExponentialDecay( lr = tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate=base_lr, initial_learning_rate=base_lr,
decay_steps=decay_steps, decay_steps=decay_steps,
...@@ -374,12 +384,11 @@ def build_learning_rate(params: base_configs.LearningRateConfig, ...@@ -374,12 +384,11 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
steps_per_epoch = params.examples_per_epoch // batch_size steps_per_epoch = params.examples_per_epoch // batch_size
boundaries = [boundary * steps_per_epoch for boundary in params.boundaries] boundaries = [boundary * steps_per_epoch for boundary in params.boundaries]
multipliers = [batch_size * multiplier for multiplier in params.multipliers] multipliers = [batch_size * multiplier for multiplier in params.multipliers]
logging.info('Using stepwise learning rate. Parameters: ' logging.info(
'boundaries: %s, values: %s', 'Using stepwise learning rate. Parameters: '
boundaries, multipliers) 'boundaries: %s, values: %s', boundaries, multipliers)
lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay( lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
boundaries=boundaries, boundaries=boundaries, values=multipliers)
values=multipliers)
elif decay_type == 'cosine_with_warmup': elif decay_type == 'cosine_with_warmup':
lr = learning_rate.CosineDecayWithWarmup( lr = learning_rate.CosineDecayWithWarmup(
batch_size=batch_size, batch_size=batch_size,
...@@ -389,7 +398,6 @@ def build_learning_rate(params: base_configs.LearningRateConfig, ...@@ -389,7 +398,6 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
if decay_type not in ['cosine_with_warmup']: if decay_type not in ['cosine_with_warmup']:
logging.info('Applying %d warmup steps to the learning rate', logging.info('Applying %d warmup steps to the learning rate',
warmup_steps) warmup_steps)
lr = learning_rate.WarmupDecaySchedule(lr, lr = learning_rate.WarmupDecaySchedule(
warmup_steps, lr, warmup_steps, warmup_lr=base_lr)
warmup_lr=base_lr)
return lr return lr
...@@ -35,10 +35,8 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -35,10 +35,8 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
return model return model
@parameterized.named_parameters( @parameterized.named_parameters(
('sgd', 'sgd', 0., False), ('sgd', 'sgd', 0., False), ('momentum', 'momentum', 0., False),
('momentum', 'momentum', 0., False), ('rmsprop', 'rmsprop', 0., False), ('adam', 'adam', 0., False),
('rmsprop', 'rmsprop', 0., False),
('adam', 'adam', 0., False),
('adamw', 'adamw', 0., False), ('adamw', 'adamw', 0., False),
('momentum_lookahead', 'momentum', 0., True), ('momentum_lookahead', 'momentum', 0., True),
('sgd_ema', 'sgd', 0.999, False), ('sgd_ema', 'sgd', 0.999, False),
...@@ -84,16 +82,13 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -84,16 +82,13 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
train_steps = 1 train_steps = 1
lr = optimizer_factory.build_learning_rate( lr = optimizer_factory.build_learning_rate(
params=params, params=params, batch_size=batch_size, train_steps=train_steps)
batch_size=batch_size,
train_steps=train_steps)
self.assertTrue( self.assertTrue(
issubclass( issubclass(
type(lr), tf.keras.optimizers.schedules.LearningRateSchedule)) type(lr), tf.keras.optimizers.schedules.LearningRateSchedule))
@parameterized.named_parameters( @parameterized.named_parameters(('exponential', 'exponential'),
('exponential', 'exponential'), ('cosine_with_warmup', 'cosine_with_warmup'))
('cosine_with_warmup', 'cosine_with_warmup'))
def test_learning_rate_with_decay_and_warmup(self, lr_decay_type): def test_learning_rate_with_decay_and_warmup(self, lr_decay_type):
"""Basic smoke test for syntax.""" """Basic smoke test for syntax."""
params = base_configs.LearningRateConfig( params = base_configs.LearningRateConfig(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment