Commit 88253ce5 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 326286926
parent 52371ffe
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils used to manipulate tensor shapes."""
import tensorflow as tf
......@@ -42,7 +41,8 @@ def assert_shape_equal(shape_a, shape_b):
all(isinstance(dim, int) for dim in shape_b)):
if shape_a != shape_b:
raise ValueError('Unequal shapes {}, {}'.format(shape_a, shape_b))
else: return tf.no_op()
else:
return tf.no_op()
else:
return tf.assert_equal(shape_a, shape_b)
......@@ -87,9 +87,7 @@ def pad_or_clip_nd(tensor, output_shape):
if shape is not None else -1 for i, shape in enumerate(output_shape)
]
clipped_tensor = tf.slice(
tensor,
begin=tf.zeros(len(clip_size), dtype=tf.int32),
size=clip_size)
tensor, begin=tf.zeros(len(clip_size), dtype=tf.int32), size=clip_size)
# Pad tensor if the shape of clipped tensor is smaller than the expected
# shape.
......@@ -99,10 +97,7 @@ def pad_or_clip_nd(tensor, output_shape):
for i, shape in enumerate(output_shape)
]
paddings = tf.stack(
[
tf.zeros(len(trailing_paddings), dtype=tf.int32),
trailing_paddings
],
[tf.zeros(len(trailing_paddings), dtype=tf.int32), trailing_paddings],
axis=1)
padded_tensor = tf.pad(tensor=clipped_tensor, paddings=paddings)
output_static_shape = [
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base target assigner module.
The job of a TargetAssigner is, for a given set of anchors (bounding boxes) and
......@@ -31,35 +30,39 @@ Note that TargetAssigners only operate on detections from a single
image at a time, so any logic for applying a TargetAssigner to multiple
images must be handled externally.
"""
import tensorflow as tf
from official.vision.detection.utils.object_detection import box_list
from official.vision.detection.utils.object_detection import shape_utils
KEYPOINTS_FIELD_NAME = 'keypoints'
class TargetAssigner(object):
"""Target assigner to compute classification and regression targets."""
def __init__(self, similarity_calc, matcher, box_coder,
negative_class_weight=1.0, unmatched_cls_target=None):
def __init__(self,
similarity_calc,
matcher,
box_coder,
negative_class_weight=1.0,
unmatched_cls_target=None):
"""Construct Object Detection Target Assigner.
Args:
similarity_calc: a RegionSimilarityCalculator
matcher: Matcher used to match groundtruth to anchors.
box_coder: BoxCoder used to encode matching groundtruth boxes with
respect to anchors.
box_coder: BoxCoder used to encode matching groundtruth boxes with respect
to anchors.
negative_class_weight: classification weight to be associated to negative
anchors (default: 1.0). The weight must be in [0., 1.].
unmatched_cls_target: a float32 tensor with shape [d_1, d_2, ..., d_k]
which is consistent with the classification target for each
anchor (and can be empty for scalar targets). This shape must thus be
compatible with the groundtruth labels that are passed to the "assign"
function (which have shape [num_gt_boxes, d_1, d_2, ..., d_k]).
If set to None, unmatched_cls_target is set to be [0] for each anchor.
which is consistent with the classification target for each anchor (and
can be empty for scalar targets). This shape must thus be compatible
with the groundtruth labels that are passed to the "assign" function
(which have shape [num_gt_boxes, d_1, d_2, ..., d_k]). If set to None,
unmatched_cls_target is set to be [0] for each anchor.
Raises:
ValueError: if similarity_calc is not a RegionSimilarityCalculator or
......@@ -78,8 +81,12 @@ class TargetAssigner(object):
def box_coder(self):
return self._box_coder
def assign(self, anchors, groundtruth_boxes, groundtruth_labels=None,
groundtruth_weights=None, **params):
def assign(self,
anchors,
groundtruth_boxes,
groundtruth_labels=None,
groundtruth_weights=None,
**params):
"""Assign classification and regression targets to each anchor.
For a given set of anchors and groundtruth detections, match anchors
......@@ -93,16 +100,16 @@ class TargetAssigner(object):
Args:
anchors: a BoxList representing N anchors
groundtruth_boxes: a BoxList representing M groundtruth boxes
groundtruth_labels: a tensor of shape [M, d_1, ... d_k]
with labels for each of the ground_truth boxes. The subshape
[d_1, ... d_k] can be empty (corresponding to scalar inputs). When set
to None, groundtruth_labels assumes a binary problem where all
ground_truth boxes get a positive label (of 1).
groundtruth_labels: a tensor of shape [M, d_1, ... d_k] with labels for
each of the ground_truth boxes. The subshape [d_1, ... d_k] can be empty
(corresponding to scalar inputs). When set to None, groundtruth_labels
assumes a binary problem where all ground_truth boxes get a positive
label (of 1).
groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors matched to a particular groundtruth box. The weights
must be in [0., 1.]. If None, all weights are set to 1.
**params: Additional keyword arguments for specific implementations of
the Matcher.
**params: Additional keyword arguments for specific implementations of the
Matcher.
Returns:
cls_targets: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k],
......@@ -125,16 +132,15 @@ class TargetAssigner(object):
raise ValueError('groundtruth_boxes must be an BoxList')
if groundtruth_labels is None:
groundtruth_labels = tf.ones(tf.expand_dims(groundtruth_boxes.num_boxes(),
0))
groundtruth_labels = tf.ones(
tf.expand_dims(groundtruth_boxes.num_boxes(), 0))
groundtruth_labels = tf.expand_dims(groundtruth_labels, -1)
unmatched_shape_assert = shape_utils.assert_shape_equal(
shape_utils.combined_static_and_dynamic_shape(groundtruth_labels)[1:],
shape_utils.combined_static_and_dynamic_shape(
self._unmatched_cls_target))
labels_and_box_shapes_assert = shape_utils.assert_shape_equal(
shape_utils.combined_static_and_dynamic_shape(
groundtruth_labels)[:1],
shape_utils.combined_static_and_dynamic_shape(groundtruth_labels)[:1],
shape_utils.combined_static_and_dynamic_shape(
groundtruth_boxes.get())[:1])
......@@ -145,11 +151,10 @@ class TargetAssigner(object):
groundtruth_weights = tf.ones([num_gt_boxes], dtype=tf.float32)
with tf.control_dependencies(
[unmatched_shape_assert, labels_and_box_shapes_assert]):
match_quality_matrix = self._similarity_calc.compare(groundtruth_boxes,
anchors)
match_quality_matrix = self._similarity_calc.compare(
groundtruth_boxes, anchors)
match = self._matcher.match(match_quality_matrix, **params)
reg_targets = self._create_regression_targets(anchors,
groundtruth_boxes,
reg_targets = self._create_regression_targets(anchors, groundtruth_boxes,
match)
cls_targets = self._create_classification_targets(groundtruth_labels,
match)
......@@ -210,8 +215,8 @@ class TargetAssigner(object):
match.match_results)
# Zero out the unmatched and ignored regression targets.
unmatched_ignored_reg_targets = tf.tile(
self._default_regression_target(), [match_results_shape[0], 1])
unmatched_ignored_reg_targets = tf.tile(self._default_regression_target(),
[match_results_shape[0], 1])
matched_anchors_mask = match.matched_column_indicator()
# To broadcast matched_anchors_mask to the same shape as
# matched_reg_targets.
......@@ -233,7 +238,7 @@ class TargetAssigner(object):
Returns:
default_target: a float32 tensor with shape [1, box_code_dimension]
"""
return tf.constant([self._box_coder.code_size*[0]], tf.float32)
return tf.constant([self._box_coder.code_size * [0]], tf.float32)
def _create_classification_targets(self, groundtruth_labels, match):
"""Create classification targets for each anchor.
......@@ -243,11 +248,11 @@ class TargetAssigner(object):
to anything are given the target self._unmatched_cls_target
Args:
groundtruth_labels: a tensor of shape [num_gt_boxes, d_1, ... d_k]
with labels for each of the ground_truth boxes. The subshape
[d_1, ... d_k] can be empty (corresponding to scalar labels).
match: a matcher.Match object that provides a matching between anchors
and groundtruth boxes.
groundtruth_labels: a tensor of shape [num_gt_boxes, d_1, ... d_k] with
labels for each of the ground_truth boxes. The subshape [d_1, ... d_k]
can be empty (corresponding to scalar labels).
match: a matcher.Match object that provides a matching between anchors and
groundtruth boxes.
Returns:
a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], where the
......@@ -267,8 +272,8 @@ class TargetAssigner(object):
negative anchor.
Args:
match: a matcher.Match object that provides a matching between anchors
and groundtruth boxes.
match: a matcher.Match object that provides a matching between anchors and
groundtruth boxes.
groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors matched to a particular groundtruth box.
......@@ -278,9 +283,7 @@ class TargetAssigner(object):
return match.gather_based_on_match(
groundtruth_weights, ignored_value=0., unmatched_value=0.)
def _create_classification_weights(self,
match,
groundtruth_weights):
def _create_classification_weights(self, match, groundtruth_weights):
"""Create classification weights for each anchor.
Positive (matched) anchors are associated with a weight of
......@@ -291,8 +294,8 @@ class TargetAssigner(object):
the case in object detection).
Args:
match: a matcher.Match object that provides a matching between anchors
and groundtruth boxes.
match: a matcher.Match object that provides a matching between anchors and
groundtruth boxes.
groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors matched to a particular groundtruth box.
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A set of functions that are used for visualization.
These functions often receive an image, perform some visualization on the image.
......@@ -21,9 +20,11 @@ The functions do not return a value, instead they modify the image itself.
"""
import collections
import functools
from absl import logging
# Set headless-friendly backend.
import matplotlib; matplotlib.use('Agg') # pylint: disable=multiple-statements
import matplotlib
matplotlib.use('Agg') # pylint: disable=multiple-statements
import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top
import numpy as np
import PIL.Image as Image
......@@ -36,7 +37,6 @@ import tensorflow as tf
from official.vision.detection.utils import box_utils
from official.vision.detection.utils.object_detection import shape_utils
_TITLE_LEFT_MARGIN = 10
_TITLE_TOP_MARGIN = 10
STANDARD_COLORS = [
......@@ -99,9 +99,9 @@ def visualize_images_with_bounding_boxes(images, box_outputs, step,
summary_writer):
"""Records subset of evaluation images with bounding boxes."""
if not isinstance(images, list):
logging.warning('visualize_images_with_bounding_boxes expects list of '
'images but received type: %s and value: %s',
type(images), images)
logging.warning(
'visualize_images_with_bounding_boxes expects list of '
'images but received type: %s and value: %s', type(images), images)
return
image_shape = tf.shape(images[0])
......@@ -140,11 +140,11 @@ def draw_bounding_box_on_image_array(image,
xmax: xmax of bounding box.
color: color to draw bounding box. Default is red.
thickness: line thickness. Default value is 4.
display_str_list: list of strings to display in box
(each to be shown on its own line).
use_normalized_coordinates: If True (default), treat coordinates
ymin, xmin, ymax, xmax as relative to the image. Otherwise treat
coordinates as absolute.
display_str_list: list of strings to display in box (each to be shown on its
own line).
use_normalized_coordinates: If True (default), treat coordinates ymin, xmin,
ymax, xmax as relative to the image. Otherwise treat coordinates as
absolute.
"""
image_pil = Image.fromarray(np.uint8(image)).convert('RGB')
draw_bounding_box_on_image(image_pil, ymin, xmin, ymax, xmax, color,
......@@ -180,11 +180,11 @@ def draw_bounding_box_on_image(image,
xmax: xmax of bounding box.
color: color to draw bounding box. Default is red.
thickness: line thickness. Default value is 4.
display_str_list: list of strings to display in box
(each to be shown on its own line).
use_normalized_coordinates: If True (default), treat coordinates
ymin, xmin, ymax, xmax as relative to the image. Otherwise treat
coordinates as absolute.
display_str_list: list of strings to display in box (each to be shown on its
own line).
use_normalized_coordinates: If True (default), treat coordinates ymin, xmin,
ymax, xmax as relative to the image. Otherwise treat coordinates as
absolute.
"""
draw = ImageDraw.Draw(image)
im_width, im_height = image.size
......@@ -193,8 +193,10 @@ def draw_bounding_box_on_image(image,
ymin * im_height, ymax * im_height)
else:
(left, right, top, bottom) = (xmin, xmax, ymin, ymax)
draw.line([(left, top), (left, bottom), (right, bottom),
(right, top), (left, top)], width=thickness, fill=color)
draw.line([(left, top), (left, bottom), (right, bottom), (right, top),
(left, top)],
width=thickness,
fill=color)
try:
font = ImageFont.truetype('arial.ttf', 24)
except IOError:
......@@ -215,15 +217,13 @@ def draw_bounding_box_on_image(image,
for display_str in display_str_list[::-1]:
text_width, text_height = font.getsize(display_str)
margin = np.ceil(0.05 * text_height)
draw.rectangle(
[(left, text_bottom - text_height - 2 * margin), (left + text_width,
text_bottom)],
fill=color)
draw.text(
(left + margin, text_bottom - text_height - margin),
display_str,
fill='black',
font=font)
draw.rectangle([(left, text_bottom - text_height - 2 * margin),
(left + text_width, text_bottom)],
fill=color)
draw.text((left + margin, text_bottom - text_height - margin),
display_str,
fill='black',
font=font)
text_bottom -= text_height - 2 * margin
......@@ -236,15 +236,13 @@ def draw_bounding_boxes_on_image_array(image,
Args:
image: a numpy array object.
boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax).
The coordinates are in normalized format between [0, 1].
boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). The
coordinates are in normalized format between [0, 1].
color: color to draw bounding box. Default is red.
thickness: line thickness. Default value is 4.
display_str_list_list: list of list of strings.
a list of strings for each bounding box.
The reason to pass a list of strings for a
bounding box is that it might contain
multiple labels.
display_str_list_list: list of list of strings. a list of strings for each
bounding box. The reason to pass a list of strings for a bounding box is
that it might contain multiple labels.
Raises:
ValueError: if boxes is not a [N, 4] array
......@@ -264,15 +262,13 @@ def draw_bounding_boxes_on_image(image,
Args:
image: a PIL.Image object.
boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax).
The coordinates are in normalized format between [0, 1].
boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax). The
coordinates are in normalized format between [0, 1].
color: color to draw bounding box. Default is red.
thickness: line thickness. Default value is 4.
display_str_list_list: list of list of strings.
a list of strings for each bounding box.
The reason to pass a list of strings for a
bounding box is that it might contain
multiple labels.
display_str_list_list: list of list of strings. a list of strings for each
bounding box. The reason to pass a list of strings for a bounding box is
that it might contain multiple labels.
Raises:
ValueError: if boxes is not a [N, 4] array
......@@ -319,8 +315,9 @@ def _visualize_boxes_and_keypoints(image, boxes, classes, scores, keypoints,
**kwargs)
def _visualize_boxes_and_masks_and_keypoints(
image, boxes, classes, scores, masks, keypoints, category_index, **kwargs):
def _visualize_boxes_and_masks_and_keypoints(image, boxes, classes, scores,
masks, keypoints, category_index,
**kwargs):
return visualize_boxes_and_labels_on_image_array(
image,
boxes,
......@@ -374,8 +371,8 @@ def draw_bounding_boxes_on_image_tensors(images,
max_boxes_to_draw: Maximum number of boxes to draw on an image. Default 20.
min_score_thresh: Minimum score threshold for visualization. Default 0.2.
use_normalized_coordinates: Whether to assume boxes and keypoints are in
normalized coordinates (as opposed to absolute coordinates).
Default is True.
normalized coordinates (as opposed to absolute coordinates). Default is
True.
Returns:
4D image tensor of type uint8, with boxes drawn on top.
......@@ -432,17 +429,15 @@ def draw_bounding_boxes_on_image_tensors(images,
_visualize_boxes,
category_index=category_index,
**visualization_keyword_args)
elems = [
true_shapes, original_shapes, images, boxes, classes, scores
]
elems = [true_shapes, original_shapes, images, boxes, classes, scores]
def draw_boxes(image_and_detections):
"""Draws boxes on image."""
true_shape = image_and_detections[0]
original_shape = image_and_detections[1]
if true_image_shape is not None:
image = shape_utils.pad_or_clip_nd(
image_and_detections[2], [true_shape[0], true_shape[1], 3])
image = shape_utils.pad_or_clip_nd(image_and_detections[2],
[true_shape[0], true_shape[1], 3])
if original_image_spatial_shape is not None:
image_and_detections[2] = _resize_original_image(image, original_shape)
......@@ -500,7 +495,8 @@ def draw_keypoints_on_image(image,
for keypoint_x, keypoint_y in zip(keypoints_x, keypoints_y):
draw.ellipse([(keypoint_x - radius, keypoint_y - radius),
(keypoint_x + radius, keypoint_y + radius)],
outline=color, fill=color)
outline=color,
fill=color)
def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
......@@ -508,8 +504,8 @@ def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
Args:
image: uint8 numpy array with shape (img_height, img_height, 3)
mask: a uint8 numpy array of shape (img_height, img_height) with
values of either 0 or 1.
mask: a uint8 numpy array of shape (img_height, img_height) with values
of either 0 or 1.
color: color to draw the keypoints with. Default is red.
alpha: transparency value between 0 and 1. (default: 0.4)
......@@ -531,7 +527,7 @@ def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
solid_color = np.expand_dims(
np.ones_like(mask), axis=2) * np.reshape(list(rgb), [1, 1, 3])
pil_solid_color = Image.fromarray(np.uint8(solid_color)).convert('RGBA')
pil_mask = Image.fromarray(np.uint8(255.0*alpha*mask)).convert('L')
pil_mask = Image.fromarray(np.uint8(255.0 * alpha * mask)).convert('L')
pil_image = Image.composite(pil_solid_color, pil_image, pil_mask)
np.copyto(image, np.array(pil_image.convert('RGB')))
......@@ -565,21 +561,20 @@ def visualize_boxes_and_labels_on_image_array(
boxes: a numpy array of shape [N, 4]
classes: a numpy array of shape [N]. Note that class indices are 1-based,
and match the keys in the label map.
scores: a numpy array of shape [N] or None. If scores=None, then
this function assumes that the boxes to be plotted are groundtruth
boxes and plot all boxes as black with no classes or scores.
scores: a numpy array of shape [N] or None. If scores=None, then this
function assumes that the boxes to be plotted are groundtruth boxes and
plot all boxes as black with no classes or scores.
category_index: a dict containing category dictionaries (each holding
category index `id` and category name `name`) keyed by category indices.
instance_masks: a numpy array of shape [N, image_height, image_width] with
values ranging between 0 and 1, can be None.
instance_boundaries: a numpy array of shape [N, image_height, image_width]
with values ranging between 0 and 1, can be None.
keypoints: a numpy array of shape [N, num_keypoints, 2], can
be None
use_normalized_coordinates: whether boxes is to be interpreted as
normalized coordinates or not.
max_boxes_to_draw: maximum number of boxes to visualize. If None, draw
all boxes.
keypoints: a numpy array of shape [N, num_keypoints, 2], can be None
use_normalized_coordinates: whether boxes is to be interpreted as normalized
coordinates or not.
max_boxes_to_draw: maximum number of boxes to visualize. If None, draw all
boxes.
min_score_thresh: minimum score threshold for a box to be visualized
agnostic_mode: boolean (default: False) controlling whether to evaluate in
class-agnostic mode or not. This mode will display scores but ignore
......@@ -624,32 +619,25 @@ def visualize_boxes_and_labels_on_image_array(
display_str = str(class_name)
if not skip_scores:
if not display_str:
display_str = '{}%'.format(int(100*scores[i]))
display_str = '{}%'.format(int(100 * scores[i]))
else:
display_str = '{}: {}%'.format(display_str, int(100*scores[i]))
display_str = '{}: {}%'.format(display_str, int(100 * scores[i]))
box_to_display_str_map[box].append(display_str)
if agnostic_mode:
box_to_color_map[box] = 'DarkOrange'
else:
box_to_color_map[box] = STANDARD_COLORS[
classes[i] % len(STANDARD_COLORS)]
box_to_color_map[box] = STANDARD_COLORS[classes[i] %
len(STANDARD_COLORS)]
# Draw all boxes onto image.
for box, color in box_to_color_map.items():
ymin, xmin, ymax, xmax = box
if instance_masks is not None:
draw_mask_on_image_array(
image,
box_to_instance_masks_map[box],
color=color
)
image, box_to_instance_masks_map[box], color=color)
if instance_boundaries is not None:
draw_mask_on_image_array(
image,
box_to_instance_boundaries_map[box],
color='red',
alpha=1.0
)
image, box_to_instance_boundaries_map[box], color='red', alpha=1.0)
draw_bounding_box_on_image_array(
image,
ymin,
......@@ -681,13 +669,15 @@ def add_cdf_image_summary(values, name):
values: a 1-D float32 tensor containing the values.
name: name for the image summary.
"""
def cdf_plot(values):
"""Numpy function to plot CDF."""
normalized_values = values / np.sum(values)
sorted_values = np.sort(normalized_values)
cumulative_values = np.cumsum(sorted_values)
fraction_of_examples = (np.arange(cumulative_values.size, dtype=np.float32)
/ cumulative_values.size)
fraction_of_examples = (
np.arange(cumulative_values.size, dtype=np.float32) /
cumulative_values.size)
fig = plt.figure(frameon=False)
ax = fig.add_subplot('111')
ax.plot(fraction_of_examples, cumulative_values)
......@@ -695,8 +685,9 @@ def add_cdf_image_summary(values, name):
ax.set_xlabel('fraction of examples')
fig.canvas.draw()
width, height = fig.get_size_inches() * fig.get_dpi()
image = np.fromstring(fig.canvas.tostring_rgb(), dtype='uint8').reshape(
1, int(height), int(width), 3)
image = np.fromstring(
fig.canvas.tostring_rgb(),
dtype='uint8').reshape(1, int(height), int(width), 3)
return image
cdf_plot = tf.compat.v1.py_func(cdf_plot, [values], tf.uint8)
......@@ -725,8 +716,8 @@ def add_hist_image_summary(values, bins, name):
fig.canvas.draw()
width, height = fig.get_size_inches() * fig.get_dpi()
image = np.fromstring(
fig.canvas.tostring_rgb(), dtype='uint8').reshape(
1, int(height), int(width), 3)
fig.canvas.tostring_rgb(),
dtype='uint8').reshape(1, int(height), int(width), 3)
return image
hist_plot = tf.compat.v1.py_func(hist_plot, [values, bins], tf.uint8)
......
......@@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function
import math
import tensorflow as tf
from typing import Any, Dict, List, Optional, Text, Tuple
......@@ -120,10 +121,8 @@ def _convert_translation_to_transform(translations: tf.Tensor) -> tf.Tensor:
)
def _convert_angles_to_transform(
angles: tf.Tensor,
image_width: tf.Tensor,
image_height: tf.Tensor) -> tf.Tensor:
def _convert_angles_to_transform(angles: tf.Tensor, image_width: tf.Tensor,
image_height: tf.Tensor) -> tf.Tensor:
"""Converts an angle or angles to a projective transform.
Args:
......@@ -173,9 +172,7 @@ def transform(image: tf.Tensor, transforms) -> tf.Tensor:
transforms = transforms[None]
image = to_4d(image)
image = image_ops.transform(
images=image,
transforms=transforms,
interpolation='nearest')
images=image, transforms=transforms, interpolation='nearest')
return from_4d(image, original_ndims)
......@@ -216,9 +213,8 @@ def rotate(image: tf.Tensor, degrees: float) -> tf.Tensor:
image_height = tf.cast(tf.shape(image)[1], tf.float32)
image_width = tf.cast(tf.shape(image)[2], tf.float32)
transforms = _convert_angles_to_transform(angles=radians,
image_width=image_width,
image_height=image_height)
transforms = _convert_angles_to_transform(
angles=radians, image_width=image_width, image_height=image_height)
# In practice, we should randomize the rotation degrees by flipping
# it negatively half the time, but that's done on 'degrees' outside
# of the function.
......@@ -279,11 +275,10 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
Args:
image: An image Tensor of type uint8.
pad_size: Specifies how big the zero mask that will be generated is that
is applied to the image. The mask will be of size
(2*pad_size x 2*pad_size).
replace: What pixel value to fill in the image in the area that has
the cutout mask applied to it.
pad_size: Specifies how big the zero mask that will be generated is that is
applied to the image. The mask will be of size (2*pad_size x 2*pad_size).
replace: What pixel value to fill in the image in the area that has the
cutout mask applied to it.
Returns:
An image Tensor that is of type uint8.
......@@ -293,30 +288,30 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
# Sample the center location in the image where the zero mask will be applied.
cutout_center_height = tf.random.uniform(
shape=[], minval=0, maxval=image_height,
dtype=tf.int32)
shape=[], minval=0, maxval=image_height, dtype=tf.int32)
cutout_center_width = tf.random.uniform(
shape=[], minval=0, maxval=image_width,
dtype=tf.int32)
shape=[], minval=0, maxval=image_width, dtype=tf.int32)
lower_pad = tf.maximum(0, cutout_center_height - pad_size)
upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size)
left_pad = tf.maximum(0, cutout_center_width - pad_size)
right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size)
cutout_shape = [image_height - (lower_pad + upper_pad),
image_width - (left_pad + right_pad)]
cutout_shape = [
image_height - (lower_pad + upper_pad),
image_width - (left_pad + right_pad)
]
padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]]
mask = tf.pad(
tf.zeros(cutout_shape, dtype=image.dtype),
padding_dims, constant_values=1)
padding_dims,
constant_values=1)
mask = tf.expand_dims(mask, -1)
mask = tf.tile(mask, [1, 1, 3])
image = tf.where(
tf.equal(mask, 0),
tf.ones_like(image, dtype=image.dtype) * replace,
image)
tf.ones_like(image, dtype=image.dtype) * replace, image)
return image
......@@ -398,8 +393,8 @@ def shear_x(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
# with a matrix form of:
# [1 level
# 0 1].
image = transform(image=wrap(image),
transforms=[1., level, 0., 0., 1., 0., 0., 0.])
image = transform(
image=wrap(image), transforms=[1., level, 0., 0., 1., 0., 0., 0.])
return unwrap(image, replace)
......@@ -409,8 +404,8 @@ def shear_y(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
# with a matrix form of:
# [1 0
# level 1].
image = transform(image=wrap(image),
transforms=[1., 0., 0., level, 1., 0., 0., 0.])
image = transform(
image=wrap(image), transforms=[1., 0., 0., level, 1., 0., 0., 0.])
return unwrap(image, replace)
......@@ -460,9 +455,9 @@ def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor:
# Make image 4D for conv operation.
image = tf.expand_dims(image, 0)
# SMOOTH PIL Kernel.
kernel = tf.constant(
[[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32,
shape=[3, 3, 1, 1]) / 13.
kernel = tf.constant([[1, 1, 1], [1, 5, 1], [1, 1, 1]],
dtype=tf.float32,
shape=[3, 3, 1, 1]) / 13.
# Tile across channel dimension.
kernel = tf.tile(kernel, [1, 1, 3, 1])
strides = [1, 1, 1, 1]
......@@ -484,6 +479,7 @@ def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor:
def equalize(image: tf.Tensor) -> tf.Tensor:
"""Implements Equalize function from PIL using TF ops."""
def scale_channel(im, c):
"""Scale the data in the channel to implement equalize."""
im = tf.cast(im[:, :, c], tf.int32)
......@@ -507,9 +503,9 @@ def equalize(image: tf.Tensor) -> tf.Tensor:
# If step is zero, return the original image. Otherwise, build
# lut from the full histogram and step and then index from it.
result = tf.cond(tf.equal(step, 0),
lambda: im,
lambda: tf.gather(build_lut(histo, step), im))
result = tf.cond(
tf.equal(step, 0), lambda: im,
lambda: tf.gather(build_lut(histo, step), im))
return tf.cast(result, tf.uint8)
......@@ -582,7 +578,7 @@ def _randomly_negate_tensor(tensor):
def _rotate_level_to_arg(level: float):
level = (level/_MAX_LEVEL) * 30.
level = (level / _MAX_LEVEL) * 30.
level = _randomly_negate_tensor(level)
return (level,)
......@@ -597,18 +593,18 @@ def _shrink_level_to_arg(level: float):
def _enhance_level_to_arg(level: float):
return ((level/_MAX_LEVEL) * 1.8 + 0.1,)
return ((level / _MAX_LEVEL) * 1.8 + 0.1,)
def _shear_level_to_arg(level: float):
level = (level/_MAX_LEVEL) * 0.3
level = (level / _MAX_LEVEL) * 0.3
# Flip level to negative with 50% chance.
level = _randomly_negate_tensor(level)
return (level,)
def _translate_level_to_arg(level: float, translate_const: float):
level = (level/_MAX_LEVEL) * float(translate_const)
level = (level / _MAX_LEVEL) * float(translate_const)
# Flip level to negative with 50% chance.
level = _randomly_negate_tensor(level)
return (level,)
......@@ -618,20 +614,15 @@ def _mult_to_arg(level: float, multiplier: float = 1.):
return (int((level / _MAX_LEVEL) * multiplier),)
def _apply_func_with_prob(func: Any,
image: tf.Tensor,
args: Any,
prob: float):
def _apply_func_with_prob(func: Any, image: tf.Tensor, args: Any, prob: float):
"""Apply `func` to image w/ `args` as input with probability `prob`."""
assert isinstance(args, tuple)
# Apply the function with probability `prob`.
should_apply_op = tf.cast(
tf.floor(tf.random.uniform([], dtype=tf.float32) + prob), tf.bool)
augmented_image = tf.cond(
should_apply_op,
lambda: func(image, *args),
lambda: image)
augmented_image = tf.cond(should_apply_op, lambda: func(image, *args),
lambda: image)
return augmented_image
......@@ -709,11 +700,8 @@ def level_to_arg(cutout_const: float, translate_const: float):
return args
def _parse_policy_info(name: Text,
prob: float,
level: float,
replace_value: List[int],
cutout_const: float,
def _parse_policy_info(name: Text, prob: float, level: float,
replace_value: List[int], cutout_const: float,
translate_const: float) -> Tuple[Any, float, Any]:
"""Return the function that corresponds to `name` and update `level` param."""
func = NAME_TO_FUNC[name]
......@@ -969,8 +957,9 @@ class RandAugment(ImageAugment):
min_prob, max_prob = 0.2, 0.8
for _ in range(self.num_layers):
op_to_select = tf.random.uniform(
[], maxval=len(self.available_ops) + 1, dtype=tf.int32)
op_to_select = tf.random.uniform([],
maxval=len(self.available_ops) + 1,
dtype=tf.int32)
branch_fns = []
for (i, op_name) in enumerate(self.available_ops):
......@@ -978,11 +967,8 @@ class RandAugment(ImageAugment):
minval=min_prob,
maxval=max_prob,
dtype=tf.float32)
func, _, args = _parse_policy_info(op_name,
prob,
self.magnitude,
replace_value,
self.cutout_const,
func, _, args = _parse_policy_info(op_name, prob, self.magnitude,
replace_value, self.cutout_const,
self.translate_const)
branch_fns.append((
i,
......@@ -991,9 +977,10 @@ class RandAugment(ImageAugment):
image, *selected_args)))
# pylint:enable=g-long-lambda
image = tf.switch_case(branch_index=op_to_select,
branch_fns=branch_fns,
default=lambda: tf.identity(image))
image = tf.switch_case(
branch_index=op_to_select,
branch_fns=branch_fns,
default=lambda: tf.identity(image))
image = tf.cast(image, dtype=input_image_type)
return image
......@@ -49,24 +49,15 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
def test_transform(self, dtype):
image = tf.constant([[1, 2], [3, 4]], dtype=dtype)
self.assertAllEqual(augment.transform(image, transforms=[1]*8),
[[4, 4], [4, 4]])
self.assertAllEqual(
augment.transform(image, transforms=[1] * 8), [[4, 4], [4, 4]])
def test_translate(self, dtype):
image = tf.constant(
[[1, 0, 1, 0],
[0, 1, 0, 1],
[1, 0, 1, 0],
[0, 1, 0, 1]],
dtype=dtype)
[[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], dtype=dtype)
translations = [-1, -1]
translated = augment.translate(image=image,
translations=translations)
expected = [
[1, 0, 1, 1],
[0, 1, 0, 0],
[1, 0, 1, 1],
[1, 0, 1, 1]]
translated = augment.translate(image=image, translations=translations)
expected = [[1, 0, 1, 1], [0, 1, 0, 0], [1, 0, 1, 1], [1, 0, 1, 1]]
self.assertAllEqual(translated, expected)
def test_translate_shapes(self, dtype):
......@@ -85,9 +76,7 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
image = tf.reshape(tf.cast(tf.range(9), dtype), (3, 3))
rotation = 90.
transformed = augment.rotate(image=image, degrees=rotation)
expected = [[2, 5, 8],
[1, 4, 7],
[0, 3, 6]]
expected = [[2, 5, 8], [1, 4, 7], [0, 3, 6]]
self.assertAllEqual(transformed, expected)
def test_rotate_shapes(self, dtype):
......@@ -129,15 +118,13 @@ class AutoaugmentTest(tf.test.TestCase):
image = tf.ones((224, 224, 3), dtype=tf.uint8)
for op_name in augment.NAME_TO_FUNC:
func, _, args = augment._parse_policy_info(op_name,
prob,
magnitude,
replace_value,
cutout_const,
func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const,
translate_const)
image = func(image, *args)
self.assertEqual((224, 224, 3), image.shape)
if __name__ == '__main__':
tf.test.main()
......@@ -21,6 +21,7 @@ from __future__ import print_function
import os
from typing import Any, List, MutableMapping, Text
from absl import logging
import tensorflow as tf
......@@ -43,8 +44,9 @@ def get_callbacks(model_checkpoint: bool = True,
callbacks = []
if model_checkpoint:
ckpt_full_path = os.path.join(model_dir, 'model.ckpt-{epoch:04d}')
callbacks.append(tf.keras.callbacks.ModelCheckpoint(
ckpt_full_path, save_weights_only=True, verbose=1))
callbacks.append(
tf.keras.callbacks.ModelCheckpoint(
ckpt_full_path, save_weights_only=True, verbose=1))
if include_tensorboard:
callbacks.append(
CustomTensorBoard(
......@@ -61,13 +63,14 @@ def get_callbacks(model_checkpoint: bool = True,
if apply_moving_average:
# Save moving average model to a different file so that
# we can resume training from a checkpoint
ckpt_full_path = os.path.join(
model_dir, 'average', 'model.ckpt-{epoch:04d}')
callbacks.append(AverageModelCheckpoint(
update_weights=False,
filepath=ckpt_full_path,
save_weights_only=True,
verbose=1))
ckpt_full_path = os.path.join(model_dir, 'average',
'model.ckpt-{epoch:04d}')
callbacks.append(
AverageModelCheckpoint(
update_weights=False,
filepath=ckpt_full_path,
save_weights_only=True,
verbose=1))
callbacks.append(MovingAverageCallback())
return callbacks
......@@ -175,16 +178,13 @@ class MovingAverageCallback(tf.keras.callbacks.Callback):
**kwargs: Any additional callback arguments.
"""
def __init__(self,
overwrite_weights_on_train_end: bool = False,
**kwargs):
def __init__(self, overwrite_weights_on_train_end: bool = False, **kwargs):
super(MovingAverageCallback, self).__init__(**kwargs)
self.overwrite_weights_on_train_end = overwrite_weights_on_train_end
def set_model(self, model: tf.keras.Model):
super(MovingAverageCallback, self).set_model(model)
assert isinstance(self.model.optimizer,
optimizer_factory.MovingAverage)
assert isinstance(self.model.optimizer, optimizer_factory.MovingAverage)
self.model.optimizer.shadow_copy(self.model)
def on_test_begin(self, logs: MutableMapping[Text, Any] = None):
......@@ -204,40 +204,30 @@ class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
Taken from tfa.callbacks.AverageModelCheckpoint.
Attributes:
update_weights: If True, assign the moving average weights
to the model, and save them. If False, keep the old
non-averaged weights, but the saved model uses the
average weights.
See `tf.keras.callbacks.ModelCheckpoint` for the other args.
update_weights: If True, assign the moving average weights to the model, and
save them. If False, keep the old non-averaged weights, but the saved
model uses the average weights. See `tf.keras.callbacks.ModelCheckpoint`
for the other args.
"""
def __init__(
self,
update_weights: bool,
filepath: str,
monitor: str = 'val_loss',
verbose: int = 0,
save_best_only: bool = False,
save_weights_only: bool = False,
mode: str = 'auto',
save_freq: str = 'epoch',
**kwargs):
def __init__(self,
update_weights: bool,
filepath: str,
monitor: str = 'val_loss',
verbose: int = 0,
save_best_only: bool = False,
save_weights_only: bool = False,
mode: str = 'auto',
save_freq: str = 'epoch',
**kwargs):
self.update_weights = update_weights
super().__init__(
filepath,
monitor,
verbose,
save_best_only,
save_weights_only,
mode,
save_freq,
**kwargs)
super().__init__(filepath, monitor, verbose, save_best_only,
save_weights_only, mode, save_freq, **kwargs)
def set_model(self, model):
if not isinstance(model.optimizer, optimizer_factory.MovingAverage):
raise TypeError(
'AverageModelCheckpoint is only used when training'
'with MovingAverage')
raise TypeError('AverageModelCheckpoint is only used when training'
'with MovingAverage')
return super().set_model(model)
def _save_model(self, epoch, logs):
......
......@@ -41,7 +41,7 @@ from official.vision.image_classification.resnet import resnet_model
def get_models() -> Mapping[str, tf.keras.Model]:
"""Returns the mapping from model type name to Keras model."""
return {
return {
'efficientnet': efficientnet_model.EfficientNet.from_name,
'resnet': resnet_model.resnet50,
}
......@@ -55,7 +55,7 @@ def get_dtype_map() -> Mapping[str, tf.dtypes.DType]:
'float16': tf.float16,
'fp32': tf.float32,
'bf16': tf.bfloat16,
}
}
def _get_metrics(one_hot: bool) -> Mapping[Text, Any]:
......@@ -63,22 +63,28 @@ def _get_metrics(one_hot: bool) -> Mapping[Text, Any]:
if one_hot:
return {
# (name, metric_fn)
'acc': tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
'accuracy': tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
'top_1': tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
'top_5': tf.keras.metrics.TopKCategoricalAccuracy(
k=5,
name='top_5_accuracy'),
'acc':
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
'accuracy':
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
'top_1':
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
'top_5':
tf.keras.metrics.TopKCategoricalAccuracy(
k=5, name='top_5_accuracy'),
}
else:
return {
# (name, metric_fn)
'acc': tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
'accuracy': tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
'top_1': tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
'top_5': tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=5,
name='top_5_accuracy'),
'acc':
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
'accuracy':
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
'top_1':
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
'top_5':
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=5, name='top_5_accuracy'),
}
......@@ -94,8 +100,7 @@ def get_image_size_from_model(
def _get_dataset_builders(params: base_configs.ExperimentConfig,
strategy: tf.distribute.Strategy,
one_hot: bool
) -> Tuple[Any, Any]:
one_hot: bool) -> Tuple[Any, Any]:
"""Create and return train and validation dataset builders."""
if one_hot:
logging.warning('label_smoothing > 0, so datasets will be one hot encoded.')
......@@ -107,9 +112,7 @@ def _get_dataset_builders(params: base_configs.ExperimentConfig,
image_size = get_image_size_from_model(params)
dataset_configs = [
params.train_dataset, params.validation_dataset
]
dataset_configs = [params.train_dataset, params.validation_dataset]
builders = []
for config in dataset_configs:
......@@ -171,8 +174,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
},
}
overriding_configs = (flags_obj.config_file,
flags_obj.params_override,
overriding_configs = (flags_obj.config_file, flags_obj.params_override,
flags_overrides)
pp = pprint.PrettyPrinter()
......@@ -190,8 +192,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
return params
def resume_from_checkpoint(model: tf.keras.Model,
model_dir: str,
def resume_from_checkpoint(model: tf.keras.Model, model_dir: str,
train_steps: int) -> int:
"""Resumes from the latest checkpoint, if possible.
......@@ -226,8 +227,7 @@ def resume_from_checkpoint(model: tf.keras.Model,
def initialize(params: base_configs.ExperimentConfig,
dataset_builder: dataset_factory.DatasetBuilder):
"""Initializes backend related initializations."""
keras_utils.set_session_config(
enable_xla=params.runtime.enable_xla)
keras_utils.set_session_config(enable_xla=params.runtime.enable_xla)
performance.set_mixed_precision_policy(dataset_builder.dtype,
get_loss_scale(params))
if tf.config.list_physical_devices('GPU'):
......@@ -244,7 +244,8 @@ def initialize(params: base_configs.ExperimentConfig,
per_gpu_thread_count=params.runtime.per_gpu_thread_count,
gpu_thread_mode=params.runtime.gpu_thread_mode,
num_gpus=params.runtime.num_gpus,
datasets_num_private_threads=params.runtime.dataset_num_private_threads) # pylint:disable=line-too-long
datasets_num_private_threads=params.runtime
.dataset_num_private_threads) # pylint:disable=line-too-long
if params.runtime.batchnorm_spatial_persistent:
os.environ['TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'] = '1'
......@@ -253,9 +254,7 @@ def define_classifier_flags():
"""Defines common flags for image classification."""
hyperparams_flags.initialize_common_flags()
flags.DEFINE_string(
'data_dir',
default=None,
help='The location of the input data.')
'data_dir', default=None, help='The location of the input data.')
flags.DEFINE_string(
'mode',
default=None,
......@@ -278,8 +277,7 @@ def define_classifier_flags():
help='The interval of steps between logging of batch level stats.')
def serialize_config(params: base_configs.ExperimentConfig,
model_dir: str):
def serialize_config(params: base_configs.ExperimentConfig, model_dir: str):
"""Serializes and saves the experiment config."""
params_save_path = os.path.join(model_dir, 'params.yaml')
logging.info('Saving experiment configuration to %s', params_save_path)
......@@ -293,9 +291,8 @@ def train_and_eval(
"""Runs the train and eval path using compile/fit."""
logging.info('Running train and eval.')
distribution_utils.configure_cluster(
params.runtime.worker_hosts,
params.runtime.task_index)
distribution_utils.configure_cluster(params.runtime.worker_hosts,
params.runtime.task_index)
# Note: for TPUs, strategy and scope should be created before the dataset
strategy = strategy_override or distribution_utils.get_distribution_strategy(
......@@ -313,8 +310,9 @@ def train_and_eval(
one_hot = label_smoothing and label_smoothing > 0
builders = _get_dataset_builders(params, strategy, one_hot)
datasets = [builder.build(strategy)
if builder else None for builder in builders]
datasets = [
builder.build(strategy) if builder else None for builder in builders
]
# Unpack datasets and builders based on train/val/test splits
train_builder, validation_builder = builders # pylint: disable=unbalanced-tuple-unpacking
......@@ -351,16 +349,16 @@ def train_and_eval(
label_smoothing=params.model.loss.label_smoothing)
else:
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
model.compile(optimizer=optimizer,
loss=loss_obj,
metrics=metrics,
experimental_steps_per_execution=steps_per_loop)
model.compile(
optimizer=optimizer,
loss=loss_obj,
metrics=metrics,
experimental_steps_per_execution=steps_per_loop)
initial_epoch = 0
if params.train.resume_checkpoint:
initial_epoch = resume_from_checkpoint(model=model,
model_dir=params.model_dir,
train_steps=train_steps)
initial_epoch = resume_from_checkpoint(
model=model, model_dir=params.model_dir, train_steps=train_steps)
callbacks = custom_callbacks.get_callbacks(
model_checkpoint=params.train.callbacks.enable_checkpoint_and_export,
......@@ -399,9 +397,7 @@ def train_and_eval(
validation_dataset, steps=validation_steps, verbose=2)
# TODO(dankondratyuk): eval and save final test accuracy
stats = common.build_stats(history,
validation_output,
callbacks)
stats = common.build_stats(history, validation_output, callbacks)
return stats
......
......@@ -105,14 +105,13 @@ def get_trivial_model(num_classes: int) -> tf.keras.Model:
lr = 0.01
optimizer = tf.keras.optimizers.SGD(learning_rate=lr)
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy()
model.compile(optimizer=optimizer,
loss=loss_obj,
run_eagerly=True)
model.compile(optimizer=optimizer, loss=loss_obj, run_eagerly=True)
return model
def get_trivial_data() -> tf.data.Dataset:
"""Gets trivial data in the ImageNet size."""
def generate_data(_) -> tf.data.Dataset:
image = tf.zeros(shape=(224, 224, 3), dtype=tf.float32)
label = tf.zeros([1], dtype=tf.int32)
......@@ -120,8 +119,8 @@ def get_trivial_data() -> tf.data.Dataset:
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(generate_data,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.map(
generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(buffer_size=1).batch(1)
return dataset
......@@ -165,11 +164,10 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
'--mode=train_and_eval',
]
run = functools.partial(classifier_trainer.run,
strategy_override=distribution)
run_end_to_end(main=run,
extra_flags=train_and_eval_flags,
model_dir=model_dir)
run = functools.partial(
classifier_trainer.run, strategy_override=distribution)
run_end_to_end(
main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)
@combinations.generate(
combinations.combine(
......@@ -209,29 +207,26 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
get_params_override(export_params)
]
run = functools.partial(classifier_trainer.run,
strategy_override=distribution)
run_end_to_end(main=run,
extra_flags=train_and_eval_flags,
model_dir=model_dir)
run_end_to_end(main=run,
extra_flags=export_flags,
model_dir=model_dir)
run = functools.partial(
classifier_trainer.run, strategy_override=distribution)
run_end_to_end(
main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)
run_end_to_end(main=run, extra_flags=export_flags, model_dir=model_dir)
self.assertTrue(os.path.exists(export_path))
@combinations.generate(
combinations.combine(
distribution=[
strategy_combinations.tpu_strategy,
],
model=[
'efficientnet',
'resnet',
],
mode='eager',
dataset='imagenet',
dtype='bfloat16',
))
distribution=[
strategy_combinations.tpu_strategy,
],
model=[
'efficientnet',
'resnet',
],
mode='eager',
dataset='imagenet',
dtype='bfloat16',
))
def test_tpu_train(self, distribution, model, dataset, dtype):
"""Test train_and_eval and export for Keras classifier models."""
# Some parameters are not defined as flags (e.g. cannot run
......@@ -248,11 +243,10 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
'--mode=train_and_eval',
]
run = functools.partial(classifier_trainer.run,
strategy_override=distribution)
run_end_to_end(main=run,
extra_flags=train_and_eval_flags,
model_dir=model_dir)
run = functools.partial(
classifier_trainer.run, strategy_override=distribution)
run_end_to_end(
main=run, extra_flags=train_and_eval_flags, model_dir=model_dir)
@combinations.generate(distribution_strategy_combinations())
def test_end_to_end_invalid_mode(self, distribution, model, dataset):
......@@ -266,8 +260,8 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
get_params_override(basic_params_override()),
]
run = functools.partial(classifier_trainer.run,
strategy_override=distribution)
run = functools.partial(
classifier_trainer.run, strategy_override=distribution)
with self.assertRaises(ValueError):
run_end_to_end(main=run, extra_flags=extra_flags, model_dir=model_dir)
......@@ -292,9 +286,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
model=base_configs.ModelConfig(
model_params={
'model_name': model_name,
},
)
)
},))
size = classifier_trainer.get_image_size_from_model(config)
self.assertEqual(size, expected)
......@@ -306,16 +298,13 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
)
def test_get_loss_scale(self, loss_scale, dtype, expected):
config = base_configs.ExperimentConfig(
runtime=base_configs.RuntimeConfig(
loss_scale=loss_scale),
runtime=base_configs.RuntimeConfig(loss_scale=loss_scale),
train_dataset=dataset_factory.DatasetConfig(dtype=dtype))
ls = classifier_trainer.get_loss_scale(config, fp16_default=128)
self.assertEqual(ls, expected)
@parameterized.named_parameters(
('float16', 'float16'),
('bfloat16', 'bfloat16')
)
@parameterized.named_parameters(('float16', 'float16'),
('bfloat16', 'bfloat16'))
def test_initialize(self, dtype):
config = base_configs.ExperimentConfig(
runtime=base_configs.RuntimeConfig(
......@@ -332,6 +321,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
class EmptyClass:
pass
fake_ds_builder = EmptyClass()
fake_ds_builder.dtype = dtype
fake_ds_builder.config = EmptyClass()
......@@ -366,9 +356,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
clean_model = get_trivial_model(10)
weights_before_load = copy.deepcopy(clean_model.get_weights())
initial_epoch = classifier_trainer.resume_from_checkpoint(
model=clean_model,
model_dir=model_dir,
train_steps=train_steps)
model=clean_model, model_dir=model_dir, train_steps=train_steps)
self.assertEqual(initial_epoch, 1)
self.assertNotAllClose(weights_before_load, clean_model.get_weights())
......@@ -383,5 +371,6 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
self.assertTrue(os.path.exists(saved_params_path))
tf.io.gfile.rmtree(model_dir)
if __name__ == '__main__':
tf.test.main()
......@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from typing import Any, List, Mapping, Optional
import dataclasses
......
......@@ -37,7 +37,6 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
train: A `TrainConfig` instance.
evaluation: An `EvalConfig` instance.
model: A `ModelConfig` instance.
"""
export: base_configs.ExportConfig = base_configs.ExportConfig()
runtime: base_configs.RuntimeConfig = base_configs.RuntimeConfig()
......@@ -49,16 +48,15 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
resume_checkpoint=True,
epochs=500,
steps=None,
callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True,
enable_tensorboard=True),
callbacks=base_configs.CallbacksConfig(
enable_checkpoint_and_export=True, enable_tensorboard=True),
metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorboardConfig(track_lr=True,
write_model_weights=False),
tensorboard=base_configs.TensorboardConfig(
track_lr=True, write_model_weights=False),
set_epoch_loop=False)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1,
steps=None)
epochs_between_evals=1, steps=None)
model: base_configs.ModelConfig = \
efficientnet_config.EfficientNetModelConfig()
......@@ -82,16 +80,15 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
resume_checkpoint=True,
epochs=90,
steps=None,
callbacks=base_configs.CallbacksConfig(enable_checkpoint_and_export=True,
enable_tensorboard=True),
callbacks=base_configs.CallbacksConfig(
enable_checkpoint_and_export=True, enable_tensorboard=True),
metrics=['accuracy', 'top_5'],
time_history=base_configs.TimeHistoryConfig(log_steps=100),
tensorboard=base_configs.TensorboardConfig(track_lr=True,
write_model_weights=False),
tensorboard=base_configs.TensorboardConfig(
track_lr=True, write_model_weights=False),
set_epoch_loop=False)
evaluation: base_configs.EvalConfig = base_configs.EvalConfig(
epochs_between_evals=1,
steps=None)
epochs_between_evals=1, steps=None)
model: base_configs.ModelConfig = resnet_config.ResNetModelConfig()
......@@ -109,10 +106,8 @@ def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
if dataset not in dataset_model_config_map:
raise KeyError('Invalid dataset received. Received: {}. Supported '
'datasets include: {}'.format(
dataset,
', '.join(dataset_model_config_map.keys())))
dataset, ', '.join(dataset_model_config_map.keys())))
raise KeyError('Invalid model received. Received: {}. Supported models for'
'{} include: {}'.format(
model,
dataset,
model, dataset,
', '.join(dataset_model_config_map[dataset].keys())))
......@@ -21,6 +21,7 @@ from __future__ import print_function
import os
from typing import Any, List, Optional, Tuple, Mapping, Union
from absl import logging
from dataclasses import dataclass
import tensorflow as tf
......@@ -30,7 +31,6 @@ from official.modeling.hyperparams import base_config
from official.vision.image_classification import augment
from official.vision.image_classification import preprocessing
AUGMENTERS = {
'autoaugment': augment.AutoAugment,
'randaugment': augment.RandAugment,
......@@ -42,8 +42,8 @@ class AugmentConfig(base_config.Config):
"""Configuration for image augmenters.
Attributes:
name: The name of the image augmentation to use. Possible options are
None (default), 'autoaugment', or 'randaugment'.
name: The name of the image augmentation to use. Possible options are None
(default), 'autoaugment', or 'randaugment'.
params: Any paramaters used to initialize the augmenter.
"""
name: Optional[str] = None
......@@ -68,17 +68,17 @@ class DatasetConfig(base_config.Config):
'tfds' (load using TFDS), 'records' (load from TFRecords), or 'synthetic'
(generate dummy synthetic data without reading from files).
split: The split of the dataset. Usually 'train', 'validation', or 'test'.
image_size: The size of the image in the dataset. This assumes that
`width` == `height`. Set to 'infer' to infer the image size from TFDS
info. This requires `name` to be a registered dataset in TFDS.
num_classes: The number of classes given by the dataset. Set to 'infer'
to infer the image size from TFDS info. This requires `name` to be a
image_size: The size of the image in the dataset. This assumes that `width`
== `height`. Set to 'infer' to infer the image size from TFDS info. This
requires `name` to be a registered dataset in TFDS.
num_classes: The number of classes given by the dataset. Set to 'infer' to
infer the image size from TFDS info. This requires `name` to be a
registered dataset in TFDS.
num_channels: The number of channels given by the dataset. Set to 'infer'
to infer the image size from TFDS info. This requires `name` to be a
num_channels: The number of channels given by the dataset. Set to 'infer' to
infer the image size from TFDS info. This requires `name` to be a
registered dataset in TFDS.
num_examples: The number of examples given by the dataset. Set to 'infer'
to infer the image size from TFDS info. This requires `name` to be a
num_examples: The number of examples given by the dataset. Set to 'infer' to
infer the image size from TFDS info. This requires `name` to be a
registered dataset in TFDS.
batch_size: The base batch size for the dataset.
use_per_replica_batch_size: Whether to scale the batch size based on
......@@ -284,10 +284,10 @@ class DatasetBuilder:
"""
if strategy:
if strategy.num_replicas_in_sync != self.config.num_devices:
logging.warn('Passed a strategy with %d devices, but expected'
'%d devices.',
strategy.num_replicas_in_sync,
self.config.num_devices)
logging.warn(
'Passed a strategy with %d devices, but expected'
'%d devices.', strategy.num_replicas_in_sync,
self.config.num_devices)
dataset = strategy.experimental_distribute_datasets_from_function(
self._build)
else:
......@@ -295,8 +295,9 @@ class DatasetBuilder:
return dataset
def _build(self, input_context: tf.distribute.InputContext = None
) -> tf.data.Dataset:
def _build(
self,
input_context: tf.distribute.InputContext = None) -> tf.data.Dataset:
"""Construct a dataset end-to-end and return it.
Args:
......@@ -328,8 +329,7 @@ class DatasetBuilder:
logging.info('Using TFDS to load data.')
builder = tfds.builder(self.config.name,
data_dir=self.config.data_dir)
builder = tfds.builder(self.config.name, data_dir=self.config.data_dir)
if self.config.download:
builder.download_and_prepare()
......@@ -380,8 +380,8 @@ class DatasetBuilder:
dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
dataset = dataset.map(generate_data,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.map(
generate_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
def pipeline(self, dataset: tf.data.Dataset) -> tf.data.Dataset:
......@@ -393,14 +393,14 @@ class DatasetBuilder:
Returns:
A TensorFlow dataset outputting batched images and labels.
"""
if (self.config.builder != 'tfds' and self.input_context
and self.input_context.num_input_pipelines > 1):
if (self.config.builder != 'tfds' and self.input_context and
self.input_context.num_input_pipelines > 1):
dataset = dataset.shard(self.input_context.num_input_pipelines,
self.input_context.input_pipeline_id)
logging.info('Sharding the dataset: input_pipeline_id=%d '
'num_input_pipelines=%d',
self.input_context.num_input_pipelines,
self.input_context.input_pipeline_id)
logging.info(
'Sharding the dataset: input_pipeline_id=%d '
'num_input_pipelines=%d', self.input_context.num_input_pipelines,
self.input_context.input_pipeline_id)
if self.is_training and self.config.builder == 'records':
# Shuffle the input files.
......@@ -429,8 +429,8 @@ class DatasetBuilder:
preprocess = self.parse_record
else:
preprocess = self.preprocess
dataset = dataset.map(preprocess,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.map(
preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
if self.input_context and self.config.num_devices > 1:
if not self.config.use_per_replica_batch_size:
......@@ -444,11 +444,11 @@ class DatasetBuilder:
# The batch size of the dataset will be multiplied by the number of
# replicas automatically when strategy.distribute_datasets_from_function
# is called, so we use local batch size here.
dataset = dataset.batch(self.local_batch_size,
drop_remainder=self.is_training)
dataset = dataset.batch(
self.local_batch_size, drop_remainder=self.is_training)
else:
dataset = dataset.batch(self.global_batch_size,
drop_remainder=self.is_training)
dataset = dataset.batch(
self.global_batch_size, drop_remainder=self.is_training)
# Prefetch overlaps in-feed with training
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
......@@ -470,24 +470,15 @@ class DatasetBuilder:
def parse_record(self, record: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""Parse an ImageNet record from a serialized string Tensor."""
keys_to_features = {
'image/encoded':
tf.io.FixedLenFeature((), tf.string, ''),
'image/format':
tf.io.FixedLenFeature((), tf.string, 'jpeg'),
'image/class/label':
tf.io.FixedLenFeature([], tf.int64, -1),
'image/class/text':
tf.io.FixedLenFeature([], tf.string, ''),
'image/object/bbox/xmin':
tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/ymin':
tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/xmax':
tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/ymax':
tf.io.VarLenFeature(dtype=tf.float32),
'image/object/class/label':
tf.io.VarLenFeature(dtype=tf.int64),
'image/encoded': tf.io.FixedLenFeature((), tf.string, ''),
'image/format': tf.io.FixedLenFeature((), tf.string, 'jpeg'),
'image/class/label': tf.io.FixedLenFeature([], tf.int64, -1),
'image/class/text': tf.io.FixedLenFeature([], tf.string, ''),
'image/object/bbox/xmin': tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/ymin': tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/xmax': tf.io.VarLenFeature(dtype=tf.float32),
'image/object/bbox/ymax': tf.io.VarLenFeature(dtype=tf.float32),
'image/object/class/label': tf.io.VarLenFeature(dtype=tf.int64),
}
parsed = tf.io.parse_single_example(record, keys_to_features)
......@@ -502,8 +493,8 @@ class DatasetBuilder:
return image, label
def preprocess(self, image: tf.Tensor, label: tf.Tensor
) -> Tuple[tf.Tensor, tf.Tensor]:
def preprocess(self, image: tf.Tensor,
label: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""Apply image preprocessing and augmentation to the image and label."""
if self.is_training:
image = preprocessing.preprocess_for_train(
......
......@@ -79,7 +79,7 @@ def get_batch_norm(batch_norm_type: Text) -> tf.keras.layers.BatchNormalization:
Args:
batch_norm_type: The type of batch normalization layer implementation. `tpu`
will use `TpuBatchNormalization`.
will use `TpuBatchNormalization`.
Returns:
An instance of `tf.keras.layers.BatchNormalization`.
......@@ -95,8 +95,10 @@ def count_params(model, trainable_only=True):
if not trainable_only:
return model.count_params()
else:
return int(np.sum([tf.keras.backend.count_params(p)
for p in model.trainable_weights]))
return int(
np.sum([
tf.keras.backend.count_params(p) for p in model.trainable_weights
]))
def load_weights(model: tf.keras.Model,
......@@ -107,8 +109,8 @@ def load_weights(model: tf.keras.Model,
Args:
model: the model to load weights into
model_weights_path: the path of the model weights
weights_format: the model weights format. One of 'saved_model', 'h5',
or 'checkpoint'.
weights_format: the model weights format. One of 'saved_model', 'h5', or
'checkpoint'.
"""
if weights_format == 'saved_model':
loaded_model = tf.keras.models.load_model(model_weights_path)
......
......@@ -64,11 +64,11 @@ class ModelConfig(base_config.Config):
# (input_filters, output_filters, kernel_size, num_repeat,
# expand_ratio, strides, se_ratio)
# pylint: disable=bad-whitespace
BlockConfig.from_args(32, 16, 3, 1, 1, (1, 1), 0.25),
BlockConfig.from_args(16, 24, 3, 2, 6, (2, 2), 0.25),
BlockConfig.from_args(24, 40, 5, 2, 6, (2, 2), 0.25),
BlockConfig.from_args(40, 80, 3, 3, 6, (2, 2), 0.25),
BlockConfig.from_args(80, 112, 5, 3, 6, (1, 1), 0.25),
BlockConfig.from_args(32, 16, 3, 1, 1, (1, 1), 0.25),
BlockConfig.from_args(16, 24, 3, 2, 6, (2, 2), 0.25),
BlockConfig.from_args(24, 40, 5, 2, 6, (2, 2), 0.25),
BlockConfig.from_args(40, 80, 3, 3, 6, (2, 2), 0.25),
BlockConfig.from_args(80, 112, 5, 3, 6, (1, 1), 0.25),
BlockConfig.from_args(112, 192, 5, 4, 6, (2, 2), 0.25),
BlockConfig.from_args(192, 320, 3, 1, 6, (1, 1), 0.25),
# pylint: enable=bad-whitespace
......@@ -128,8 +128,7 @@ DENSE_KERNEL_INITIALIZER = {
}
def round_filters(filters: int,
config: ModelConfig) -> int:
def round_filters(filters: int, config: ModelConfig) -> int:
"""Round number of filters based on width coefficient."""
width_coefficient = config.width_coefficient
min_depth = config.min_depth
......@@ -189,21 +188,24 @@ def conv2d_block(inputs: tf.Tensor,
init_kwargs.update({'depthwise_initializer': CONV_KERNEL_INITIALIZER})
else:
conv2d = tf.keras.layers.Conv2D
init_kwargs.update({'filters': conv_filters,
'kernel_initializer': CONV_KERNEL_INITIALIZER})
init_kwargs.update({
'filters': conv_filters,
'kernel_initializer': CONV_KERNEL_INITIALIZER
})
x = conv2d(**init_kwargs)(inputs)
if use_batch_norm:
bn_axis = 1 if data_format == 'channels_first' else -1
x = batch_norm(axis=bn_axis,
momentum=bn_momentum,
epsilon=bn_epsilon,
name=name + '_bn')(x)
x = batch_norm(
axis=bn_axis,
momentum=bn_momentum,
epsilon=bn_epsilon,
name=name + '_bn')(
x)
if activation is not None:
x = tf.keras.layers.Activation(activation,
name=name + '_activation')(x)
x = tf.keras.layers.Activation(activation, name=name + '_activation')(x)
return x
......@@ -235,42 +237,43 @@ def mb_conv_block(inputs: tf.Tensor,
if block.fused_conv:
# If we use fused mbconv, skip expansion and use regular conv.
x = conv2d_block(x,
filters,
config,
kernel_size=block.kernel_size,
strides=block.strides,
activation=activation,
name=prefix + 'fused')
x = conv2d_block(
x,
filters,
config,
kernel_size=block.kernel_size,
strides=block.strides,
activation=activation,
name=prefix + 'fused')
else:
if block.expand_ratio != 1:
# Expansion phase
kernel_size = (1, 1) if use_depthwise else (3, 3)
x = conv2d_block(x,
filters,
config,
kernel_size=kernel_size,
activation=activation,
name=prefix + 'expand')
x = conv2d_block(
x,
filters,
config,
kernel_size=kernel_size,
activation=activation,
name=prefix + 'expand')
# Depthwise Convolution
if use_depthwise:
x = conv2d_block(x,
conv_filters=None,
config=config,
kernel_size=block.kernel_size,
strides=block.strides,
activation=activation,
depthwise=True,
name=prefix + 'depthwise')
x = conv2d_block(
x,
conv_filters=None,
config=config,
kernel_size=block.kernel_size,
strides=block.strides,
activation=activation,
depthwise=True,
name=prefix + 'depthwise')
# Squeeze and Excitation phase
if use_se:
assert block.se_ratio is not None
assert 0 < block.se_ratio <= 1
num_reduced_filters = max(1, int(
block.input_filters * block.se_ratio
))
num_reduced_filters = max(1, int(block.input_filters * block.se_ratio))
if data_format == 'channels_first':
se_shape = (filters, 1, 1)
......@@ -280,53 +283,51 @@ def mb_conv_block(inputs: tf.Tensor,
se = tf.keras.layers.GlobalAveragePooling2D(name=prefix + 'se_squeeze')(x)
se = tf.keras.layers.Reshape(se_shape, name=prefix + 'se_reshape')(se)
se = conv2d_block(se,
num_reduced_filters,
config,
use_bias=True,
use_batch_norm=False,
activation=activation,
name=prefix + 'se_reduce')
se = conv2d_block(se,
filters,
config,
use_bias=True,
use_batch_norm=False,
activation='sigmoid',
name=prefix + 'se_expand')
se = conv2d_block(
se,
num_reduced_filters,
config,
use_bias=True,
use_batch_norm=False,
activation=activation,
name=prefix + 'se_reduce')
se = conv2d_block(
se,
filters,
config,
use_bias=True,
use_batch_norm=False,
activation='sigmoid',
name=prefix + 'se_expand')
x = tf.keras.layers.multiply([x, se], name=prefix + 'se_excite')
# Output phase
x = conv2d_block(x,
block.output_filters,
config,
activation=None,
name=prefix + 'project')
x = conv2d_block(
x, block.output_filters, config, activation=None, name=prefix + 'project')
# Add identity so that quantization-aware training can insert quantization
# ops correctly.
x = tf.keras.layers.Activation(tf_utils.get_activation('identity'),
name=prefix + 'id')(x)
x = tf.keras.layers.Activation(
tf_utils.get_activation('identity'), name=prefix + 'id')(
x)
if (block.id_skip
and all(s == 1 for s in block.strides)
and block.input_filters == block.output_filters):
if (block.id_skip and all(s == 1 for s in block.strides) and
block.input_filters == block.output_filters):
if drop_connect_rate and drop_connect_rate > 0:
# Apply dropconnect
# The only difference between dropout and dropconnect in TF is scaling by
# drop_connect_rate during training. See:
# https://github.com/keras-team/keras/pull/9898#issuecomment-380577612
x = tf.keras.layers.Dropout(drop_connect_rate,
noise_shape=(None, 1, 1, 1),
name=prefix + 'drop')(x)
x = tf.keras.layers.Dropout(
drop_connect_rate, noise_shape=(None, 1, 1, 1), name=prefix + 'drop')(
x)
x = tf.keras.layers.add([x, inputs], name=prefix + 'add')
return x
def efficientnet(image_input: tf.keras.layers.Input,
config: ModelConfig):
def efficientnet(image_input: tf.keras.layers.Input, config: ModelConfig):
"""Creates an EfficientNet graph given the model parameters.
This function is wrapped by the `EfficientNet` class to make a tf.keras.Model.
......@@ -357,19 +358,18 @@ def efficientnet(image_input: tf.keras.layers.Input,
# Happens on GPU/TPU if available.
x = tf.keras.layers.Permute((3, 1, 2))(x)
if rescale_input:
x = preprocessing.normalize_images(x,
num_channels=input_channels,
dtype=dtype,
data_format=data_format)
x = preprocessing.normalize_images(
x, num_channels=input_channels, dtype=dtype, data_format=data_format)
# Build stem
x = conv2d_block(x,
round_filters(stem_base_filters, config),
config,
kernel_size=[3, 3],
strides=[2, 2],
activation=activation,
name='stem')
x = conv2d_block(
x,
round_filters(stem_base_filters, config),
config,
kernel_size=[3, 3],
strides=[2, 2],
activation=activation,
name='stem')
# Build blocks
num_blocks_total = sum(
......@@ -391,10 +391,7 @@ def efficientnet(image_input: tf.keras.layers.Input,
x = mb_conv_block(x, block, config, block_prefix)
block_num += 1
if block.num_repeat > 1:
block = block.replace(
input_filters=block.output_filters,
strides=[1, 1]
)
block = block.replace(input_filters=block.output_filters, strides=[1, 1])
for block_idx in range(block.num_repeat - 1):
drop_rate = drop_connect_rate * float(block_num) / num_blocks_total
......@@ -404,11 +401,12 @@ def efficientnet(image_input: tf.keras.layers.Input,
block_num += 1
# Build top
x = conv2d_block(x,
round_filters(top_base_filters, config),
config,
activation=activation,
name='top')
x = conv2d_block(
x,
round_filters(top_base_filters, config),
config,
activation=activation,
name='top')
# Build classifier
x = tf.keras.layers.GlobalAveragePooling2D(name='top_pool')(x)
......@@ -419,7 +417,8 @@ def efficientnet(image_input: tf.keras.layers.Input,
kernel_initializer=DENSE_KERNEL_INITIALIZER,
kernel_regularizer=tf.keras.regularizers.l2(weight_decay),
bias_regularizer=tf.keras.regularizers.l2(weight_decay),
name='logits')(x)
name='logits')(
x)
x = tf.keras.layers.Activation('softmax', name='probs')(x)
return x
......@@ -439,8 +438,7 @@ class EfficientNet(tf.keras.Model):
Args:
config: (optional) the main model parameters to create the model
overrides: (optional) a dict containing keys that can override
config
overrides: (optional) a dict containing keys that can override config
"""
overrides = overrides or {}
config = config or ModelConfig()
......@@ -457,9 +455,7 @@ class EfficientNet(tf.keras.Model):
# Cast to float32 in case we have a different model dtype
output = tf.cast(output, tf.float32)
logging.info('Building model %s with params %s',
model_name,
self.config)
logging.info('Building model %s with params %s', model_name, self.config)
super(EfficientNet, self).__init__(
inputs=image_input, outputs=output, name=model_name)
......@@ -477,8 +473,8 @@ class EfficientNet(tf.keras.Model):
Args:
model_name: the predefined model name
model_weights_path: the path to the weights (h5 file or saved model dir)
weights_format: the model weights format. One of 'saved_model', 'h5',
or 'checkpoint'.
weights_format: the model weights format. One of 'saved_model', 'h5', or
'checkpoint'.
overrides: (optional) a dict containing keys that can override config
Returns:
......@@ -498,8 +494,7 @@ class EfficientNet(tf.keras.Model):
model = cls(config=config, overrides=overrides)
if model_weights_path:
common_modules.load_weights(model,
model_weights_path,
weights_format=weights_format)
common_modules.load_weights(
model, model_weights_path, weights_format=weights_format)
return model
......@@ -30,10 +30,8 @@ from official.vision.image_classification.efficientnet import efficientnet_model
FLAGS = flags.FLAGS
flags.DEFINE_string("model_name", None,
"EfficientNet model name.")
flags.DEFINE_string("model_path", None,
"File path to TF model checkpoint.")
flags.DEFINE_string("model_name", None, "EfficientNet model name.")
flags.DEFINE_string("model_path", None, "File path to TF model checkpoint.")
flags.DEFINE_string("export_path", None,
"TF-Hub SavedModel destination path to export.")
......@@ -65,5 +63,6 @@ def main(argv):
export_tfhub(FLAGS.model_path, FLAGS.export_path, FLAGS.model_name)
if __name__ == "__main__":
app.run(main)
......@@ -29,11 +29,10 @@ BASE_LEARNING_RATE = 0.1
class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
"""A wrapper for LearningRateSchedule that includes warmup steps."""
def __init__(
self,
lr_schedule: tf.keras.optimizers.schedules.LearningRateSchedule,
warmup_steps: int,
warmup_lr: Optional[float] = None):
def __init__(self,
lr_schedule: tf.keras.optimizers.schedules.LearningRateSchedule,
warmup_steps: int,
warmup_lr: Optional[float] = None):
"""Add warmup decay to a learning rate schedule.
Args:
......@@ -42,7 +41,6 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
warmup_lr: an optional field for the final warmup learning rate. This
should be provided if the base `lr_schedule` does not contain this
field.
"""
super(WarmupDecaySchedule, self).__init__()
self._lr_schedule = lr_schedule
......@@ -63,8 +61,7 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
global_step_recomp = tf.cast(step, dtype)
warmup_steps = tf.cast(self._warmup_steps, dtype)
warmup_lr = initial_learning_rate * global_step_recomp / warmup_steps
lr = tf.cond(global_step_recomp < warmup_steps,
lambda: warmup_lr,
lr = tf.cond(global_step_recomp < warmup_steps, lambda: warmup_lr,
lambda: lr)
return lr
......
......@@ -37,14 +37,13 @@ class LearningRateTests(tf.test.TestCase):
decay_steps=decay_steps,
decay_rate=decay_rate)
lr = learning_rate.WarmupDecaySchedule(
lr_schedule=base_lr,
warmup_steps=warmup_steps)
lr_schedule=base_lr, warmup_steps=warmup_steps)
for step in range(warmup_steps - 1):
config = lr.get_config()
self.assertEqual(config['warmup_steps'], warmup_steps)
self.assertAllClose(self.evaluate(lr(step)),
step / warmup_steps * initial_lr)
self.assertAllClose(
self.evaluate(lr(step)), step / warmup_steps * initial_lr)
def test_cosine_decay_with_warmup(self):
"""Basic computational test for cosine decay with warmup."""
......
......@@ -19,6 +19,7 @@ from __future__ import print_function
import os
# Import libraries
from absl import app
from absl import flags
from absl import logging
......
......@@ -58,7 +58,8 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
"""Test Keras MNIST model with `strategy`."""
extra_flags = [
"-train_epochs", "1",
"-train_epochs",
"1",
# Let TFDS find the metadata folder automatically
"--data_dir="
]
......@@ -72,9 +73,10 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
tf.data.Dataset.from_tensor_slices(dummy_data),
)
run = functools.partial(mnist_main.run,
datasets_override=datasets,
strategy_override=distribution)
run = functools.partial(
mnist_main.run,
datasets_override=datasets,
strategy_override=distribution)
integration.run_synthetic(
main=run,
......
......@@ -65,19 +65,19 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
"""Construct a new MovingAverage optimizer.
Args:
optimizer: `tf.keras.optimizers.Optimizer` that will be
used to compute and apply gradients.
average_decay: float. Decay to use to maintain the moving averages
of trained variables.
optimizer: `tf.keras.optimizers.Optimizer` that will be used to compute
and apply gradients.
average_decay: float. Decay to use to maintain the moving averages of
trained variables.
start_step: int. What step to start the moving average.
dynamic_decay: bool. Whether to change the decay based on the number
of optimizer updates. Decay will start at 0.1 and gradually increase
up to `average_decay` after each optimizer update. This behavior is
similar to `tf.train.ExponentialMovingAverage` in TF 1.x.
name: Optional name for the operations created when applying
gradients. Defaults to "moving_average".
**kwargs: keyword arguments. Allowed to be {`clipnorm`,
`clipvalue`, `lr`, `decay`}.
dynamic_decay: bool. Whether to change the decay based on the number of
optimizer updates. Decay will start at 0.1 and gradually increase up to
`average_decay` after each optimizer update. This behavior is similar to
`tf.train.ExponentialMovingAverage` in TF 1.x.
name: Optional name for the operations created when applying gradients.
Defaults to "moving_average".
**kwargs: keyword arguments. Allowed to be {`clipnorm`, `clipvalue`, `lr`,
`decay`}.
"""
super(MovingAverage, self).__init__(name, **kwargs)
self._optimizer = optimizer
......@@ -128,8 +128,8 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
strategy.extended.update(v_moving, _apply_moving, args=(v_normal,))
ctx = tf.distribute.get_replica_context()
return ctx.merge_call(_update, args=(zip(self._average_weights,
self._model_weights),))
return ctx.merge_call(
_update, args=(zip(self._average_weights, self._model_weights),))
def swap_weights(self):
"""Swap the average and moving weights.
......@@ -148,12 +148,15 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
@tf.function
def _swap_weights(self):
def fn_0(a, b):
a.assign_add(b)
return a
def fn_1(b, a):
b.assign(a - b)
return b
def fn_2(a, b):
a.assign_sub(b)
return a
......@@ -174,12 +177,14 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
Args:
var_list: List of model variables to be assigned to their average.
Returns:
assign_op: The op corresponding to the assignment operation of
variables to their average.
"""
assign_op = tf.group([
var.assign(self.get_slot(var, 'average')) for var in var_list
var.assign(self.get_slot(var, 'average'))
for var in var_list
if var.trainable
])
return assign_op
......@@ -256,13 +261,13 @@ def build_optimizer(
"""Build the optimizer based on name.
Args:
optimizer_name: String representation of the optimizer name. Examples:
sgd, momentum, rmsprop.
optimizer_name: String representation of the optimizer name. Examples: sgd,
momentum, rmsprop.
base_learning_rate: `tf.keras.optimizers.schedules.LearningRateSchedule`
base learning rate.
params: String -> Any dictionary representing the optimizer params.
This should contain optimizer specific parameters such as
`base_learning_rate`, `decay`, etc.
params: String -> Any dictionary representing the optimizer params. This
should contain optimizer specific parameters such as `base_learning_rate`,
`decay`, etc.
model: The `tf.keras.Model`. This is used for the shadow copy if using
`MovingAverage`.
......@@ -279,43 +284,47 @@ def build_optimizer(
if optimizer_name == 'sgd':
logging.info('Using SGD optimizer')
nesterov = params.get('nesterov', False)
optimizer = tf.keras.optimizers.SGD(learning_rate=base_learning_rate,
nesterov=nesterov)
optimizer = tf.keras.optimizers.SGD(
learning_rate=base_learning_rate, nesterov=nesterov)
elif optimizer_name == 'momentum':
logging.info('Using momentum optimizer')
nesterov = params.get('nesterov', False)
optimizer = tf.keras.optimizers.SGD(learning_rate=base_learning_rate,
momentum=params['momentum'],
nesterov=nesterov)
optimizer = tf.keras.optimizers.SGD(
learning_rate=base_learning_rate,
momentum=params['momentum'],
nesterov=nesterov)
elif optimizer_name == 'rmsprop':
logging.info('Using RMSProp')
rho = params.get('decay', None) or params.get('rho', 0.9)
momentum = params.get('momentum', 0.9)
epsilon = params.get('epsilon', 1e-07)
optimizer = tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate,
rho=rho,
momentum=momentum,
epsilon=epsilon)
optimizer = tf.keras.optimizers.RMSprop(
learning_rate=base_learning_rate,
rho=rho,
momentum=momentum,
epsilon=epsilon)
elif optimizer_name == 'adam':
logging.info('Using Adam')
beta_1 = params.get('beta_1', 0.9)
beta_2 = params.get('beta_2', 0.999)
epsilon = params.get('epsilon', 1e-07)
optimizer = tf.keras.optimizers.Adam(learning_rate=base_learning_rate,
beta_1=beta_1,
beta_2=beta_2,
epsilon=epsilon)
optimizer = tf.keras.optimizers.Adam(
learning_rate=base_learning_rate,
beta_1=beta_1,
beta_2=beta_2,
epsilon=epsilon)
elif optimizer_name == 'adamw':
logging.info('Using AdamW')
weight_decay = params.get('weight_decay', 0.01)
beta_1 = params.get('beta_1', 0.9)
beta_2 = params.get('beta_2', 0.999)
epsilon = params.get('epsilon', 1e-07)
optimizer = tfa.optimizers.AdamW(weight_decay=weight_decay,
learning_rate=base_learning_rate,
beta_1=beta_1,
beta_2=beta_2,
epsilon=epsilon)
optimizer = tfa.optimizers.AdamW(
weight_decay=weight_decay,
learning_rate=base_learning_rate,
beta_1=beta_1,
beta_2=beta_2,
epsilon=epsilon)
else:
raise ValueError('Unknown optimizer %s' % optimizer_name)
......@@ -330,8 +339,7 @@ def build_optimizer(
raise ValueError('`model` must be provided if using `MovingAverage`.')
logging.info('Including moving average decay.')
optimizer = MovingAverage(
optimizer=optimizer,
average_decay=moving_average_decay)
optimizer=optimizer, average_decay=moving_average_decay)
optimizer.shadow_copy(model)
return optimizer
......@@ -358,13 +366,15 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
if lr_multiplier and lr_multiplier > 0:
# Scale the learning rate based on the batch size and a multiplier
base_lr *= lr_multiplier * batch_size
logging.info('Scaling the learning rate based on the batch size '
'multiplier. New base_lr: %f', base_lr)
logging.info(
'Scaling the learning rate based on the batch size '
'multiplier. New base_lr: %f', base_lr)
if decay_type == 'exponential':
logging.info('Using exponential learning rate with: '
'initial_learning_rate: %f, decay_steps: %d, '
'decay_rate: %f', base_lr, decay_steps, decay_rate)
logging.info(
'Using exponential learning rate with: '
'initial_learning_rate: %f, decay_steps: %d, '
'decay_rate: %f', base_lr, decay_steps, decay_rate)
lr = tf.keras.optimizers.schedules.ExponentialDecay(
initial_learning_rate=base_lr,
decay_steps=decay_steps,
......@@ -374,12 +384,11 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
steps_per_epoch = params.examples_per_epoch // batch_size
boundaries = [boundary * steps_per_epoch for boundary in params.boundaries]
multipliers = [batch_size * multiplier for multiplier in params.multipliers]
logging.info('Using stepwise learning rate. Parameters: '
'boundaries: %s, values: %s',
boundaries, multipliers)
logging.info(
'Using stepwise learning rate. Parameters: '
'boundaries: %s, values: %s', boundaries, multipliers)
lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
boundaries=boundaries,
values=multipliers)
boundaries=boundaries, values=multipliers)
elif decay_type == 'cosine_with_warmup':
lr = learning_rate.CosineDecayWithWarmup(
batch_size=batch_size,
......@@ -389,7 +398,6 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
if decay_type not in ['cosine_with_warmup']:
logging.info('Applying %d warmup steps to the learning rate',
warmup_steps)
lr = learning_rate.WarmupDecaySchedule(lr,
warmup_steps,
warmup_lr=base_lr)
lr = learning_rate.WarmupDecaySchedule(
lr, warmup_steps, warmup_lr=base_lr)
return lr
......@@ -35,10 +35,8 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
return model
@parameterized.named_parameters(
('sgd', 'sgd', 0., False),
('momentum', 'momentum', 0., False),
('rmsprop', 'rmsprop', 0., False),
('adam', 'adam', 0., False),
('sgd', 'sgd', 0., False), ('momentum', 'momentum', 0., False),
('rmsprop', 'rmsprop', 0., False), ('adam', 'adam', 0., False),
('adamw', 'adamw', 0., False),
('momentum_lookahead', 'momentum', 0., True),
('sgd_ema', 'sgd', 0.999, False),
......@@ -84,16 +82,13 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
train_steps = 1
lr = optimizer_factory.build_learning_rate(
params=params,
batch_size=batch_size,
train_steps=train_steps)
params=params, batch_size=batch_size, train_steps=train_steps)
self.assertTrue(
issubclass(
type(lr), tf.keras.optimizers.schedules.LearningRateSchedule))
@parameterized.named_parameters(
('exponential', 'exponential'),
('cosine_with_warmup', 'cosine_with_warmup'))
@parameterized.named_parameters(('exponential', 'exponential'),
('cosine_with_warmup', 'cosine_with_warmup'))
def test_learning_rate_with_decay_and_warmup(self, lr_decay_type):
"""Basic smoke test for syntax."""
params = base_configs.LearningRateConfig(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment