Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
88253ce5
Commit
88253ce5
authored
Aug 12, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Aug 12, 2020
Browse files
Internal change
PiperOrigin-RevId: 326286926
parent
52371ffe
Changes
205
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
533 additions
and
612 deletions
+533
-612
official/vision/detection/utils/object_detection/shape_utils.py
...al/vision/detection/utils/object_detection/shape_utils.py
+4
-9
official/vision/detection/utils/object_detection/target_assigner.py
...ision/detection/utils/object_detection/target_assigner.py
+46
-43
official/vision/detection/utils/object_detection/visualization_utils.py
...n/detection/utils/object_detection/visualization_utils.py
+73
-82
official/vision/image_classification/augment.py
official/vision/image_classification/augment.py
+48
-61
official/vision/image_classification/augment_test.py
official/vision/image_classification/augment_test.py
+9
-22
official/vision/image_classification/callbacks.py
official/vision/image_classification/callbacks.py
+32
-42
official/vision/image_classification/classifier_trainer.py
official/vision/image_classification/classifier_trainer.py
+42
-46
official/vision/image_classification/classifier_trainer_test.py
...al/vision/image_classification/classifier_trainer_test.py
+37
-48
official/vision/image_classification/configs/base_configs.py
official/vision/image_classification/configs/base_configs.py
+0
-1
official/vision/image_classification/configs/configs.py
official/vision/image_classification/configs/configs.py
+12
-17
official/vision/image_classification/dataset_factory.py
official/vision/image_classification/dataset_factory.py
+45
-54
official/vision/image_classification/efficientnet/common_modules.py
...ision/image_classification/efficientnet/common_modules.py
+7
-5
official/vision/image_classification/efficientnet/efficientnet_model.py
...n/image_classification/efficientnet/efficientnet_model.py
+94
-99
official/vision/image_classification/efficientnet/tfhub_export.py
.../vision/image_classification/efficientnet/tfhub_export.py
+3
-4
official/vision/image_classification/learning_rate.py
official/vision/image_classification/learning_rate.py
+5
-8
official/vision/image_classification/learning_rate_test.py
official/vision/image_classification/learning_rate_test.py
+3
-4
official/vision/image_classification/mnist_main.py
official/vision/image_classification/mnist_main.py
+1
-0
official/vision/image_classification/mnist_test.py
official/vision/image_classification/mnist_test.py
+6
-4
official/vision/image_classification/optimizer_factory.py
official/vision/image_classification/optimizer_factory.py
+61
-53
official/vision/image_classification/optimizer_factory_test.py
...ial/vision/image_classification/optimizer_factory_test.py
+5
-10
No files found.
official/vision/detection/utils/object_detection/shape_utils.py
View file @
88253ce5
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Utils used to manipulate tensor shapes."""
"""Utils used to manipulate tensor shapes."""
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -42,7 +41,8 @@ def assert_shape_equal(shape_a, shape_b):
...
@@ -42,7 +41,8 @@ def assert_shape_equal(shape_a, shape_b):
all
(
isinstance
(
dim
,
int
)
for
dim
in
shape_b
)):
all
(
isinstance
(
dim
,
int
)
for
dim
in
shape_b
)):
if
shape_a
!=
shape_b
:
if
shape_a
!=
shape_b
:
raise
ValueError
(
'Unequal shapes {}, {}'
.
format
(
shape_a
,
shape_b
))
raise
ValueError
(
'Unequal shapes {}, {}'
.
format
(
shape_a
,
shape_b
))
else
:
return
tf
.
no_op
()
else
:
return
tf
.
no_op
()
else
:
else
:
return
tf
.
assert_equal
(
shape_a
,
shape_b
)
return
tf
.
assert_equal
(
shape_a
,
shape_b
)
...
@@ -87,9 +87,7 @@ def pad_or_clip_nd(tensor, output_shape):
...
@@ -87,9 +87,7 @@ def pad_or_clip_nd(tensor, output_shape):
if
shape
is
not
None
else
-
1
for
i
,
shape
in
enumerate
(
output_shape
)
if
shape
is
not
None
else
-
1
for
i
,
shape
in
enumerate
(
output_shape
)
]
]
clipped_tensor
=
tf
.
slice
(
clipped_tensor
=
tf
.
slice
(
tensor
,
tensor
,
begin
=
tf
.
zeros
(
len
(
clip_size
),
dtype
=
tf
.
int32
),
size
=
clip_size
)
begin
=
tf
.
zeros
(
len
(
clip_size
),
dtype
=
tf
.
int32
),
size
=
clip_size
)
# Pad tensor if the shape of clipped tensor is smaller than the expected
# Pad tensor if the shape of clipped tensor is smaller than the expected
# shape.
# shape.
...
@@ -99,10 +97,7 @@ def pad_or_clip_nd(tensor, output_shape):
...
@@ -99,10 +97,7 @@ def pad_or_clip_nd(tensor, output_shape):
for
i
,
shape
in
enumerate
(
output_shape
)
for
i
,
shape
in
enumerate
(
output_shape
)
]
]
paddings
=
tf
.
stack
(
paddings
=
tf
.
stack
(
[
[
tf
.
zeros
(
len
(
trailing_paddings
),
dtype
=
tf
.
int32
),
trailing_paddings
],
tf
.
zeros
(
len
(
trailing_paddings
),
dtype
=
tf
.
int32
),
trailing_paddings
],
axis
=
1
)
axis
=
1
)
padded_tensor
=
tf
.
pad
(
tensor
=
clipped_tensor
,
paddings
=
paddings
)
padded_tensor
=
tf
.
pad
(
tensor
=
clipped_tensor
,
paddings
=
paddings
)
output_static_shape
=
[
output_static_shape
=
[
...
...
official/vision/detection/utils/object_detection/target_assigner.py
View file @
88253ce5
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Base target assigner module.
"""Base target assigner module.
The job of a TargetAssigner is, for a given set of anchors (bounding boxes) and
The job of a TargetAssigner is, for a given set of anchors (bounding boxes) and
...
@@ -31,35 +30,39 @@ Note that TargetAssigners only operate on detections from a single
...
@@ -31,35 +30,39 @@ Note that TargetAssigners only operate on detections from a single
image at a time, so any logic for applying a TargetAssigner to multiple
image at a time, so any logic for applying a TargetAssigner to multiple
images must be handled externally.
images must be handled externally.
"""
"""
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.detection.utils.object_detection
import
box_list
from
official.vision.detection.utils.object_detection
import
box_list
from
official.vision.detection.utils.object_detection
import
shape_utils
from
official.vision.detection.utils.object_detection
import
shape_utils
KEYPOINTS_FIELD_NAME
=
'keypoints'
KEYPOINTS_FIELD_NAME
=
'keypoints'
class
TargetAssigner
(
object
):
class
TargetAssigner
(
object
):
"""Target assigner to compute classification and regression targets."""
"""Target assigner to compute classification and regression targets."""
def
__init__
(
self
,
similarity_calc
,
matcher
,
box_coder
,
def
__init__
(
self
,
negative_class_weight
=
1.0
,
unmatched_cls_target
=
None
):
similarity_calc
,
matcher
,
box_coder
,
negative_class_weight
=
1.0
,
unmatched_cls_target
=
None
):
"""Construct Object Detection Target Assigner.
"""Construct Object Detection Target Assigner.
Args:
Args:
similarity_calc: a RegionSimilarityCalculator
similarity_calc: a RegionSimilarityCalculator
matcher: Matcher used to match groundtruth to anchors.
matcher: Matcher used to match groundtruth to anchors.
box_coder: BoxCoder used to encode matching groundtruth boxes with
box_coder: BoxCoder used to encode matching groundtruth boxes with
respect
respect
to anchors.
to anchors.
negative_class_weight: classification weight to be associated to negative
negative_class_weight: classification weight to be associated to negative
anchors (default: 1.0). The weight must be in [0., 1.].
anchors (default: 1.0). The weight must be in [0., 1.].
unmatched_cls_target: a float32 tensor with shape [d_1, d_2, ..., d_k]
unmatched_cls_target: a float32 tensor with shape [d_1, d_2, ..., d_k]
which is consistent with the classification target for each
which is consistent with the classification target for each
anchor (and
anchor (and
can be empty for scalar targets). This shape must thus be
can be empty for scalar targets). This shape must thus be
compatible
compatible
with the groundtruth labels that are passed to the "assign"
with the groundtruth labels that are passed to the "assign"
function
function
(which have shape [num_gt_boxes, d_1, d_2, ..., d_k]).
(which have shape [num_gt_boxes, d_1, d_2, ..., d_k]).
If set to None,
If set to None,
unmatched_cls_target is set to be [0] for each anchor.
unmatched_cls_target is set to be [0] for each anchor.
Raises:
Raises:
ValueError: if similarity_calc is not a RegionSimilarityCalculator or
ValueError: if similarity_calc is not a RegionSimilarityCalculator or
...
@@ -78,8 +81,12 @@ class TargetAssigner(object):
...
@@ -78,8 +81,12 @@ class TargetAssigner(object):
def
box_coder
(
self
):
def
box_coder
(
self
):
return
self
.
_box_coder
return
self
.
_box_coder
def
assign
(
self
,
anchors
,
groundtruth_boxes
,
groundtruth_labels
=
None
,
def
assign
(
self
,
groundtruth_weights
=
None
,
**
params
):
anchors
,
groundtruth_boxes
,
groundtruth_labels
=
None
,
groundtruth_weights
=
None
,
**
params
):
"""Assign classification and regression targets to each anchor.
"""Assign classification and regression targets to each anchor.
For a given set of anchors and groundtruth detections, match anchors
For a given set of anchors and groundtruth detections, match anchors
...
@@ -93,16 +100,16 @@ class TargetAssigner(object):
...
@@ -93,16 +100,16 @@ class TargetAssigner(object):
Args:
Args:
anchors: a BoxList representing N anchors
anchors: a BoxList representing N anchors
groundtruth_boxes: a BoxList representing M groundtruth boxes
groundtruth_boxes: a BoxList representing M groundtruth boxes
groundtruth_labels: a tensor of shape [M, d_1, ... d_k]
groundtruth_labels: a tensor of shape [M, d_1, ... d_k]
with labels for
with labels for
each of the ground_truth boxes. The subshape
each of the ground_truth boxes. The subshape
[d_1, ... d_k] can be empty
[d_1, ... d_k] can be empty
(corresponding to scalar inputs). When set
(corresponding to scalar inputs). When set
to None, groundtruth_labels
to None, groundtruth_labels
assumes a binary problem where all
assumes a binary problem where all
ground_truth boxes get a positive
ground_truth boxes get a positive
label (of 1).
label (of 1).
groundtruth_weights: a float tensor of shape [M] indicating the weight to
groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors match to a particular groundtruth box. The weights
assign to all anchors match to a particular groundtruth box. The weights
must be in [0., 1.]. If None, all weights are set to 1.
must be in [0., 1.]. If None, all weights are set to 1.
**params: Additional keyword arguments for specific implementations of
**params: Additional keyword arguments for specific implementations of
the
the
Matcher.
Matcher.
Returns:
Returns:
cls_targets: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k],
cls_targets: a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k],
...
@@ -125,16 +132,15 @@ class TargetAssigner(object):
...
@@ -125,16 +132,15 @@ class TargetAssigner(object):
raise
ValueError
(
'groundtruth_boxes must be an BoxList'
)
raise
ValueError
(
'groundtruth_boxes must be an BoxList'
)
if
groundtruth_labels
is
None
:
if
groundtruth_labels
is
None
:
groundtruth_labels
=
tf
.
ones
(
tf
.
expand_dims
(
groundtruth_boxes
.
num_boxes
(),
groundtruth_labels
=
tf
.
ones
(
0
))
tf
.
expand_dims
(
groundtruth_boxes
.
num_boxes
(),
0
))
groundtruth_labels
=
tf
.
expand_dims
(
groundtruth_labels
,
-
1
)
groundtruth_labels
=
tf
.
expand_dims
(
groundtruth_labels
,
-
1
)
unmatched_shape_assert
=
shape_utils
.
assert_shape_equal
(
unmatched_shape_assert
=
shape_utils
.
assert_shape_equal
(
shape_utils
.
combined_static_and_dynamic_shape
(
groundtruth_labels
)[
1
:],
shape_utils
.
combined_static_and_dynamic_shape
(
groundtruth_labels
)[
1
:],
shape_utils
.
combined_static_and_dynamic_shape
(
shape_utils
.
combined_static_and_dynamic_shape
(
self
.
_unmatched_cls_target
))
self
.
_unmatched_cls_target
))
labels_and_box_shapes_assert
=
shape_utils
.
assert_shape_equal
(
labels_and_box_shapes_assert
=
shape_utils
.
assert_shape_equal
(
shape_utils
.
combined_static_and_dynamic_shape
(
shape_utils
.
combined_static_and_dynamic_shape
(
groundtruth_labels
)[:
1
],
groundtruth_labels
)[:
1
],
shape_utils
.
combined_static_and_dynamic_shape
(
shape_utils
.
combined_static_and_dynamic_shape
(
groundtruth_boxes
.
get
())[:
1
])
groundtruth_boxes
.
get
())[:
1
])
...
@@ -145,11 +151,10 @@ class TargetAssigner(object):
...
@@ -145,11 +151,10 @@ class TargetAssigner(object):
groundtruth_weights
=
tf
.
ones
([
num_gt_boxes
],
dtype
=
tf
.
float32
)
groundtruth_weights
=
tf
.
ones
([
num_gt_boxes
],
dtype
=
tf
.
float32
)
with
tf
.
control_dependencies
(
with
tf
.
control_dependencies
(
[
unmatched_shape_assert
,
labels_and_box_shapes_assert
]):
[
unmatched_shape_assert
,
labels_and_box_shapes_assert
]):
match_quality_matrix
=
self
.
_similarity_calc
.
compare
(
groundtruth_boxes
,
match_quality_matrix
=
self
.
_similarity_calc
.
compare
(
anchors
)
groundtruth_boxes
,
anchors
)
match
=
self
.
_matcher
.
match
(
match_quality_matrix
,
**
params
)
match
=
self
.
_matcher
.
match
(
match_quality_matrix
,
**
params
)
reg_targets
=
self
.
_create_regression_targets
(
anchors
,
reg_targets
=
self
.
_create_regression_targets
(
anchors
,
groundtruth_boxes
,
groundtruth_boxes
,
match
)
match
)
cls_targets
=
self
.
_create_classification_targets
(
groundtruth_labels
,
cls_targets
=
self
.
_create_classification_targets
(
groundtruth_labels
,
match
)
match
)
...
@@ -210,8 +215,8 @@ class TargetAssigner(object):
...
@@ -210,8 +215,8 @@ class TargetAssigner(object):
match
.
match_results
)
match
.
match_results
)
# Zero out the unmatched and ignored regression targets.
# Zero out the unmatched and ignored regression targets.
unmatched_ignored_reg_targets
=
tf
.
tile
(
unmatched_ignored_reg_targets
=
tf
.
tile
(
self
.
_default_regression_target
(),
self
.
_default_regression_target
(),
[
match_results_shape
[
0
],
1
])
[
match_results_shape
[
0
],
1
])
matched_anchors_mask
=
match
.
matched_column_indicator
()
matched_anchors_mask
=
match
.
matched_column_indicator
()
# To broadcast matched_anchors_mask to the same shape as
# To broadcast matched_anchors_mask to the same shape as
# matched_reg_targets.
# matched_reg_targets.
...
@@ -233,7 +238,7 @@ class TargetAssigner(object):
...
@@ -233,7 +238,7 @@ class TargetAssigner(object):
Returns:
Returns:
default_target: a float32 tensor with shape [1, box_code_dimension]
default_target: a float32 tensor with shape [1, box_code_dimension]
"""
"""
return
tf
.
constant
([
self
.
_box_coder
.
code_size
*
[
0
]],
tf
.
float32
)
return
tf
.
constant
([
self
.
_box_coder
.
code_size
*
[
0
]],
tf
.
float32
)
def
_create_classification_targets
(
self
,
groundtruth_labels
,
match
):
def
_create_classification_targets
(
self
,
groundtruth_labels
,
match
):
"""Create classification targets for each anchor.
"""Create classification targets for each anchor.
...
@@ -243,11 +248,11 @@ class TargetAssigner(object):
...
@@ -243,11 +248,11 @@ class TargetAssigner(object):
to anything are given the target self._unmatched_cls_target
to anything are given the target self._unmatched_cls_target
Args:
Args:
groundtruth_labels: a tensor of shape [num_gt_boxes, d_1, ... d_k]
groundtruth_labels: a tensor of shape [num_gt_boxes, d_1, ... d_k]
with
with
labels for each of the ground_truth boxes. The subshape
labels for each of the ground_truth boxes. The subshape
[d_1, ... d_k]
[d_1, ... d_k]
can be empty (corresponding to scalar labels).
can be empty (corresponding to scalar labels).
match: a matcher.Match object that provides a matching between anchors
match: a matcher.Match object that provides a matching between anchors
and
and
groundtruth boxes.
groundtruth boxes.
Returns:
Returns:
a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], where the
a float32 tensor with shape [num_anchors, d_1, d_2 ... d_k], where the
...
@@ -267,8 +272,8 @@ class TargetAssigner(object):
...
@@ -267,8 +272,8 @@ class TargetAssigner(object):
negative anchor.
negative anchor.
Args:
Args:
match: a matcher.Match object that provides a matching between anchors
match: a matcher.Match object that provides a matching between anchors
and
and
groundtruth boxes.
groundtruth boxes.
groundtruth_weights: a float tensor of shape [M] indicating the weight to
groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors match to a particular groundtruth box.
assign to all anchors match to a particular groundtruth box.
...
@@ -278,9 +283,7 @@ class TargetAssigner(object):
...
@@ -278,9 +283,7 @@ class TargetAssigner(object):
return
match
.
gather_based_on_match
(
return
match
.
gather_based_on_match
(
groundtruth_weights
,
ignored_value
=
0.
,
unmatched_value
=
0.
)
groundtruth_weights
,
ignored_value
=
0.
,
unmatched_value
=
0.
)
def
_create_classification_weights
(
self
,
def
_create_classification_weights
(
self
,
match
,
groundtruth_weights
):
match
,
groundtruth_weights
):
"""Create classification weights for each anchor.
"""Create classification weights for each anchor.
Positive (matched) anchors are associated with a weight of
Positive (matched) anchors are associated with a weight of
...
@@ -291,8 +294,8 @@ class TargetAssigner(object):
...
@@ -291,8 +294,8 @@ class TargetAssigner(object):
the case in object detection).
the case in object detection).
Args:
Args:
match: a matcher.Match object that provides a matching between anchors
match: a matcher.Match object that provides a matching between anchors
and
and
groundtruth boxes.
groundtruth boxes.
groundtruth_weights: a float tensor of shape [M] indicating the weight to
groundtruth_weights: a float tensor of shape [M] indicating the weight to
assign to all anchors match to a particular groundtruth box.
assign to all anchors match to a particular groundtruth box.
...
...
official/vision/detection/utils/object_detection/visualization_utils.py
View file @
88253ce5
...
@@ -12,7 +12,6 @@
...
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""A set of functions that are used for visualization.
"""A set of functions that are used for visualization.
These functions often receive an image, perform some visualization on the image.
These functions often receive an image, perform some visualization on the image.
...
@@ -21,9 +20,11 @@ The functions do not return a value, instead they modify the image itself.
...
@@ -21,9 +20,11 @@ The functions do not return a value, instead they modify the image itself.
"""
"""
import
collections
import
collections
import
functools
import
functools
from
absl
import
logging
from
absl
import
logging
# Set headless-friendly backend.
# Set headless-friendly backend.
import
matplotlib
;
matplotlib
.
use
(
'Agg'
)
# pylint: disable=multiple-statements
import
matplotlib
matplotlib
.
use
(
'Agg'
)
# pylint: disable=multiple-statements
import
matplotlib.pyplot
as
plt
# pylint: disable=g-import-not-at-top
import
matplotlib.pyplot
as
plt
# pylint: disable=g-import-not-at-top
import
numpy
as
np
import
numpy
as
np
import
PIL.Image
as
Image
import
PIL.Image
as
Image
...
@@ -36,7 +37,6 @@ import tensorflow as tf
...
@@ -36,7 +37,6 @@ import tensorflow as tf
from
official.vision.detection.utils
import
box_utils
from
official.vision.detection.utils
import
box_utils
from
official.vision.detection.utils.object_detection
import
shape_utils
from
official.vision.detection.utils.object_detection
import
shape_utils
_TITLE_LEFT_MARGIN
=
10
_TITLE_LEFT_MARGIN
=
10
_TITLE_TOP_MARGIN
=
10
_TITLE_TOP_MARGIN
=
10
STANDARD_COLORS
=
[
STANDARD_COLORS
=
[
...
@@ -99,9 +99,9 @@ def visualize_images_with_bounding_boxes(images, box_outputs, step,
...
@@ -99,9 +99,9 @@ def visualize_images_with_bounding_boxes(images, box_outputs, step,
summary_writer
):
summary_writer
):
"""Records subset of evaluation images with bounding boxes."""
"""Records subset of evaluation images with bounding boxes."""
if
not
isinstance
(
images
,
list
):
if
not
isinstance
(
images
,
list
):
logging
.
warning
(
'visualize_images_with_bounding_boxes expects list of '
logging
.
warning
(
'images but received type: %s and value: %s'
,
'visualize_images_with_bounding_boxes expects list of '
type
(
images
),
images
)
'images but received type: %s and value: %s'
,
type
(
images
),
images
)
return
return
image_shape
=
tf
.
shape
(
images
[
0
])
image_shape
=
tf
.
shape
(
images
[
0
])
...
@@ -140,11 +140,11 @@ def draw_bounding_box_on_image_array(image,
...
@@ -140,11 +140,11 @@ def draw_bounding_box_on_image_array(image,
xmax: xmax of bounding box.
xmax: xmax of bounding box.
color: color to draw bounding box. Default is red.
color: color to draw bounding box. Default is red.
thickness: line thickness. Default value is 4.
thickness: line thickness. Default value is 4.
display_str_list: list of strings to display in box
display_str_list: list of strings to display in box
(each to be shown on its
(each to be shown on its
own line).
own line).
use_normalized_coordinates: If True (default), treat coordinates
use_normalized_coordinates: If True (default), treat coordinates
ymin, xmin,
ymin, xmin,
ymax, xmax as relative to the image. Otherwise treat
ymax, xmax as relative to the image. Otherwise treat
coordinates as
coordinates as
absolute.
absolute.
"""
"""
image_pil
=
Image
.
fromarray
(
np
.
uint8
(
image
)).
convert
(
'RGB'
)
image_pil
=
Image
.
fromarray
(
np
.
uint8
(
image
)).
convert
(
'RGB'
)
draw_bounding_box_on_image
(
image_pil
,
ymin
,
xmin
,
ymax
,
xmax
,
color
,
draw_bounding_box_on_image
(
image_pil
,
ymin
,
xmin
,
ymax
,
xmax
,
color
,
...
@@ -180,11 +180,11 @@ def draw_bounding_box_on_image(image,
...
@@ -180,11 +180,11 @@ def draw_bounding_box_on_image(image,
xmax: xmax of bounding box.
xmax: xmax of bounding box.
color: color to draw bounding box. Default is red.
color: color to draw bounding box. Default is red.
thickness: line thickness. Default value is 4.
thickness: line thickness. Default value is 4.
display_str_list: list of strings to display in box
display_str_list: list of strings to display in box
(each to be shown on its
(each to be shown on its
own line).
own line).
use_normalized_coordinates: If True (default), treat coordinates
use_normalized_coordinates: If True (default), treat coordinates
ymin, xmin,
ymin, xmin,
ymax, xmax as relative to the image. Otherwise treat
ymax, xmax as relative to the image. Otherwise treat
coordinates as
coordinates as
absolute.
absolute.
"""
"""
draw
=
ImageDraw
.
Draw
(
image
)
draw
=
ImageDraw
.
Draw
(
image
)
im_width
,
im_height
=
image
.
size
im_width
,
im_height
=
image
.
size
...
@@ -193,8 +193,10 @@ def draw_bounding_box_on_image(image,
...
@@ -193,8 +193,10 @@ def draw_bounding_box_on_image(image,
ymin
*
im_height
,
ymax
*
im_height
)
ymin
*
im_height
,
ymax
*
im_height
)
else
:
else
:
(
left
,
right
,
top
,
bottom
)
=
(
xmin
,
xmax
,
ymin
,
ymax
)
(
left
,
right
,
top
,
bottom
)
=
(
xmin
,
xmax
,
ymin
,
ymax
)
draw
.
line
([(
left
,
top
),
(
left
,
bottom
),
(
right
,
bottom
),
draw
.
line
([(
left
,
top
),
(
left
,
bottom
),
(
right
,
bottom
),
(
right
,
top
),
(
right
,
top
),
(
left
,
top
)],
width
=
thickness
,
fill
=
color
)
(
left
,
top
)],
width
=
thickness
,
fill
=
color
)
try
:
try
:
font
=
ImageFont
.
truetype
(
'arial.ttf'
,
24
)
font
=
ImageFont
.
truetype
(
'arial.ttf'
,
24
)
except
IOError
:
except
IOError
:
...
@@ -215,15 +217,13 @@ def draw_bounding_box_on_image(image,
...
@@ -215,15 +217,13 @@ def draw_bounding_box_on_image(image,
for
display_str
in
display_str_list
[::
-
1
]:
for
display_str
in
display_str_list
[::
-
1
]:
text_width
,
text_height
=
font
.
getsize
(
display_str
)
text_width
,
text_height
=
font
.
getsize
(
display_str
)
margin
=
np
.
ceil
(
0.05
*
text_height
)
margin
=
np
.
ceil
(
0.05
*
text_height
)
draw
.
rectangle
(
draw
.
rectangle
([(
left
,
text_bottom
-
text_height
-
2
*
margin
),
[(
left
,
text_bottom
-
text_height
-
2
*
margin
),
(
left
+
text_width
,
(
left
+
text_width
,
text_bottom
)],
text_bottom
)],
fill
=
color
)
fill
=
color
)
draw
.
text
((
left
+
margin
,
text_bottom
-
text_height
-
margin
),
draw
.
text
(
display_str
,
(
left
+
margin
,
text_bottom
-
text_height
-
margin
),
fill
=
'black'
,
display_str
,
font
=
font
)
fill
=
'black'
,
font
=
font
)
text_bottom
-=
text_height
-
2
*
margin
text_bottom
-=
text_height
-
2
*
margin
...
@@ -236,15 +236,13 @@ def draw_bounding_boxes_on_image_array(image,
...
@@ -236,15 +236,13 @@ def draw_bounding_boxes_on_image_array(image,
Args:
Args:
image: a numpy array object.
image: a numpy array object.
boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax).
boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax).
The
The
coordinates are in normalized format between [0, 1].
coordinates are in normalized format between [0, 1].
color: color to draw bounding box. Default is red.
color: color to draw bounding box. Default is red.
thickness: line thickness. Default value is 4.
thickness: line thickness. Default value is 4.
display_str_list_list: list of list of strings.
display_str_list_list: list of list of strings. a list of strings for each
a list of strings for each bounding box.
bounding box. The reason to pass a list of strings for a bounding box is
The reason to pass a list of strings for a
that it might contain multiple labels.
bounding box is that it might contain
multiple labels.
Raises:
Raises:
ValueError: if boxes is not a [N, 4] array
ValueError: if boxes is not a [N, 4] array
...
@@ -264,15 +262,13 @@ def draw_bounding_boxes_on_image(image,
...
@@ -264,15 +262,13 @@ def draw_bounding_boxes_on_image(image,
Args:
Args:
image: a PIL.Image object.
image: a PIL.Image object.
boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax).
boxes: a 2 dimensional numpy array of [N, 4]: (ymin, xmin, ymax, xmax).
The
The
coordinates are in normalized format between [0, 1].
coordinates are in normalized format between [0, 1].
color: color to draw bounding box. Default is red.
color: color to draw bounding box. Default is red.
thickness: line thickness. Default value is 4.
thickness: line thickness. Default value is 4.
display_str_list_list: list of list of strings.
display_str_list_list: list of list of strings. a list of strings for each
a list of strings for each bounding box.
bounding box. The reason to pass a list of strings for a bounding box is
The reason to pass a list of strings for a
that it might contain multiple labels.
bounding box is that it might contain
multiple labels.
Raises:
Raises:
ValueError: if boxes is not a [N, 4] array
ValueError: if boxes is not a [N, 4] array
...
@@ -319,8 +315,9 @@ def _visualize_boxes_and_keypoints(image, boxes, classes, scores, keypoints,
...
@@ -319,8 +315,9 @@ def _visualize_boxes_and_keypoints(image, boxes, classes, scores, keypoints,
**
kwargs
)
**
kwargs
)
def
_visualize_boxes_and_masks_and_keypoints
(
def
_visualize_boxes_and_masks_and_keypoints
(
image
,
boxes
,
classes
,
scores
,
image
,
boxes
,
classes
,
scores
,
masks
,
keypoints
,
category_index
,
**
kwargs
):
masks
,
keypoints
,
category_index
,
**
kwargs
):
return
visualize_boxes_and_labels_on_image_array
(
return
visualize_boxes_and_labels_on_image_array
(
image
,
image
,
boxes
,
boxes
,
...
@@ -374,8 +371,8 @@ def draw_bounding_boxes_on_image_tensors(images,
...
@@ -374,8 +371,8 @@ def draw_bounding_boxes_on_image_tensors(images,
max_boxes_to_draw: Maximum number of boxes to draw on an image. Default 20.
max_boxes_to_draw: Maximum number of boxes to draw on an image. Default 20.
min_score_thresh: Minimum score threshold for visualization. Default 0.2.
min_score_thresh: Minimum score threshold for visualization. Default 0.2.
use_normalized_coordinates: Whether to assume boxes and kepoints are in
use_normalized_coordinates: Whether to assume boxes and kepoints are in
normalized coordinates (as opposed to absolute coordiantes).
normalized coordinates (as opposed to absolute coordiantes).
Default is
Default is
True.
True.
Returns:
Returns:
4D image tensor of type uint8, with boxes drawn on top.
4D image tensor of type uint8, with boxes drawn on top.
...
@@ -432,17 +429,15 @@ def draw_bounding_boxes_on_image_tensors(images,
...
@@ -432,17 +429,15 @@ def draw_bounding_boxes_on_image_tensors(images,
_visualize_boxes
,
_visualize_boxes
,
category_index
=
category_index
,
category_index
=
category_index
,
**
visualization_keyword_args
)
**
visualization_keyword_args
)
elems
=
[
elems
=
[
true_shapes
,
original_shapes
,
images
,
boxes
,
classes
,
scores
]
true_shapes
,
original_shapes
,
images
,
boxes
,
classes
,
scores
]
def
draw_boxes
(
image_and_detections
):
def
draw_boxes
(
image_and_detections
):
"""Draws boxes on image."""
"""Draws boxes on image."""
true_shape
=
image_and_detections
[
0
]
true_shape
=
image_and_detections
[
0
]
original_shape
=
image_and_detections
[
1
]
original_shape
=
image_and_detections
[
1
]
if
true_image_shape
is
not
None
:
if
true_image_shape
is
not
None
:
image
=
shape_utils
.
pad_or_clip_nd
(
image
=
shape_utils
.
pad_or_clip_nd
(
image_and_detections
[
2
],
image_and_detections
[
2
],
[
true_shape
[
0
],
true_shape
[
1
],
3
])
[
true_shape
[
0
],
true_shape
[
1
],
3
])
if
original_image_spatial_shape
is
not
None
:
if
original_image_spatial_shape
is
not
None
:
image_and_detections
[
2
]
=
_resize_original_image
(
image
,
original_shape
)
image_and_detections
[
2
]
=
_resize_original_image
(
image
,
original_shape
)
...
@@ -500,7 +495,8 @@ def draw_keypoints_on_image(image,
...
@@ -500,7 +495,8 @@ def draw_keypoints_on_image(image,
for
keypoint_x
,
keypoint_y
in
zip
(
keypoints_x
,
keypoints_y
):
for
keypoint_x
,
keypoint_y
in
zip
(
keypoints_x
,
keypoints_y
):
draw
.
ellipse
([(
keypoint_x
-
radius
,
keypoint_y
-
radius
),
draw
.
ellipse
([(
keypoint_x
-
radius
,
keypoint_y
-
radius
),
(
keypoint_x
+
radius
,
keypoint_y
+
radius
)],
(
keypoint_x
+
radius
,
keypoint_y
+
radius
)],
outline
=
color
,
fill
=
color
)
outline
=
color
,
fill
=
color
)
def
draw_mask_on_image_array
(
image
,
mask
,
color
=
'red'
,
alpha
=
0.4
):
def
draw_mask_on_image_array
(
image
,
mask
,
color
=
'red'
,
alpha
=
0.4
):
...
@@ -508,8 +504,8 @@ def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
...
@@ -508,8 +504,8 @@ def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
Args:
Args:
image: uint8 numpy array with shape (img_height, img_height, 3)
image: uint8 numpy array with shape (img_height, img_height, 3)
mask: a uint8 numpy array of shape (img_height, img_height) with
mask: a uint8 numpy array of shape (img_height, img_height) with
values
values
between either 0 or 1.
between either 0 or 1.
color: color to draw the keypoints with. Default is red.
color: color to draw the keypoints with. Default is red.
alpha: transparency value between 0 and 1. (default: 0.4)
alpha: transparency value between 0 and 1. (default: 0.4)
...
@@ -531,7 +527,7 @@ def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
...
@@ -531,7 +527,7 @@ def draw_mask_on_image_array(image, mask, color='red', alpha=0.4):
solid_color
=
np
.
expand_dims
(
solid_color
=
np
.
expand_dims
(
np
.
ones_like
(
mask
),
axis
=
2
)
*
np
.
reshape
(
list
(
rgb
),
[
1
,
1
,
3
])
np
.
ones_like
(
mask
),
axis
=
2
)
*
np
.
reshape
(
list
(
rgb
),
[
1
,
1
,
3
])
pil_solid_color
=
Image
.
fromarray
(
np
.
uint8
(
solid_color
)).
convert
(
'RGBA'
)
pil_solid_color
=
Image
.
fromarray
(
np
.
uint8
(
solid_color
)).
convert
(
'RGBA'
)
pil_mask
=
Image
.
fromarray
(
np
.
uint8
(
255.0
*
alpha
*
mask
)).
convert
(
'L'
)
pil_mask
=
Image
.
fromarray
(
np
.
uint8
(
255.0
*
alpha
*
mask
)).
convert
(
'L'
)
pil_image
=
Image
.
composite
(
pil_solid_color
,
pil_image
,
pil_mask
)
pil_image
=
Image
.
composite
(
pil_solid_color
,
pil_image
,
pil_mask
)
np
.
copyto
(
image
,
np
.
array
(
pil_image
.
convert
(
'RGB'
)))
np
.
copyto
(
image
,
np
.
array
(
pil_image
.
convert
(
'RGB'
)))
...
@@ -565,21 +561,20 @@ def visualize_boxes_and_labels_on_image_array(
...
@@ -565,21 +561,20 @@ def visualize_boxes_and_labels_on_image_array(
boxes: a numpy array of shape [N, 4]
boxes: a numpy array of shape [N, 4]
classes: a numpy array of shape [N]. Note that class indices are 1-based,
classes: a numpy array of shape [N]. Note that class indices are 1-based,
and match the keys in the label map.
and match the keys in the label map.
scores: a numpy array of shape [N] or None. If scores=None, then
scores: a numpy array of shape [N] or None. If scores=None, then
this
this
function assumes that the boxes to be plotted are groundtruth
function assumes that the boxes to be plotted are groundtruth
boxes and
boxes and
plot all boxes as black with no classes or scores.
plot all boxes as black with no classes or scores.
category_index: a dict containing category dictionaries (each holding
category_index: a dict containing category dictionaries (each holding
category index `id` and category name `name`) keyed by category indices.
category index `id` and category name `name`) keyed by category indices.
instance_masks: a numpy array of shape [N, image_height, image_width] with
instance_masks: a numpy array of shape [N, image_height, image_width] with
values ranging between 0 and 1, can be None.
values ranging between 0 and 1, can be None.
instance_boundaries: a numpy array of shape [N, image_height, image_width]
instance_boundaries: a numpy array of shape [N, image_height, image_width]
with values ranging between 0 and 1, can be None.
with values ranging between 0 and 1, can be None.
keypoints: a numpy array of shape [N, num_keypoints, 2], can
keypoints: a numpy array of shape [N, num_keypoints, 2], can be None
be None
use_normalized_coordinates: whether boxes is to be interpreted as normalized
use_normalized_coordinates: whether boxes is to be interpreted as
coordinates or not.
normalized coordinates or not.
max_boxes_to_draw: maximum number of boxes to visualize. If None, draw all
max_boxes_to_draw: maximum number of boxes to visualize. If None, draw
boxes.
all boxes.
min_score_thresh: minimum score threshold for a box to be visualized
min_score_thresh: minimum score threshold for a box to be visualized
agnostic_mode: boolean (default: False) controlling whether to evaluate in
agnostic_mode: boolean (default: False) controlling whether to evaluate in
class-agnostic mode or not. This mode will display scores but ignore
class-agnostic mode or not. This mode will display scores but ignore
...
@@ -624,32 +619,25 @@ def visualize_boxes_and_labels_on_image_array(
...
@@ -624,32 +619,25 @@ def visualize_boxes_and_labels_on_image_array(
display_str
=
str
(
class_name
)
display_str
=
str
(
class_name
)
if
not
skip_scores
:
if
not
skip_scores
:
if
not
display_str
:
if
not
display_str
:
display_str
=
'{}%'
.
format
(
int
(
100
*
scores
[
i
]))
display_str
=
'{}%'
.
format
(
int
(
100
*
scores
[
i
]))
else
:
else
:
display_str
=
'{}: {}%'
.
format
(
display_str
,
int
(
100
*
scores
[
i
]))
display_str
=
'{}: {}%'
.
format
(
display_str
,
int
(
100
*
scores
[
i
]))
box_to_display_str_map
[
box
].
append
(
display_str
)
box_to_display_str_map
[
box
].
append
(
display_str
)
if
agnostic_mode
:
if
agnostic_mode
:
box_to_color_map
[
box
]
=
'DarkOrange'
box_to_color_map
[
box
]
=
'DarkOrange'
else
:
else
:
box_to_color_map
[
box
]
=
STANDARD_COLORS
[
box_to_color_map
[
box
]
=
STANDARD_COLORS
[
classes
[
i
]
%
classes
[
i
]
%
len
(
STANDARD_COLORS
)]
len
(
STANDARD_COLORS
)]
# Draw all boxes onto image.
# Draw all boxes onto image.
for
box
,
color
in
box_to_color_map
.
items
():
for
box
,
color
in
box_to_color_map
.
items
():
ymin
,
xmin
,
ymax
,
xmax
=
box
ymin
,
xmin
,
ymax
,
xmax
=
box
if
instance_masks
is
not
None
:
if
instance_masks
is
not
None
:
draw_mask_on_image_array
(
draw_mask_on_image_array
(
image
,
image
,
box_to_instance_masks_map
[
box
],
color
=
color
)
box_to_instance_masks_map
[
box
],
color
=
color
)
if
instance_boundaries
is
not
None
:
if
instance_boundaries
is
not
None
:
draw_mask_on_image_array
(
draw_mask_on_image_array
(
image
,
image
,
box_to_instance_boundaries_map
[
box
],
color
=
'red'
,
alpha
=
1.0
)
box_to_instance_boundaries_map
[
box
],
color
=
'red'
,
alpha
=
1.0
)
draw_bounding_box_on_image_array
(
draw_bounding_box_on_image_array
(
image
,
image
,
ymin
,
ymin
,
...
@@ -681,13 +669,15 @@ def add_cdf_image_summary(values, name):
...
@@ -681,13 +669,15 @@ def add_cdf_image_summary(values, name):
values: a 1-D float32 tensor containing the values.
values: a 1-D float32 tensor containing the values.
name: name for the image summary.
name: name for the image summary.
"""
"""
def
cdf_plot
(
values
):
def
cdf_plot
(
values
):
"""Numpy function to plot CDF."""
"""Numpy function to plot CDF."""
normalized_values
=
values
/
np
.
sum
(
values
)
normalized_values
=
values
/
np
.
sum
(
values
)
sorted_values
=
np
.
sort
(
normalized_values
)
sorted_values
=
np
.
sort
(
normalized_values
)
cumulative_values
=
np
.
cumsum
(
sorted_values
)
cumulative_values
=
np
.
cumsum
(
sorted_values
)
fraction_of_examples
=
(
np
.
arange
(
cumulative_values
.
size
,
dtype
=
np
.
float32
)
fraction_of_examples
=
(
/
cumulative_values
.
size
)
np
.
arange
(
cumulative_values
.
size
,
dtype
=
np
.
float32
)
/
cumulative_values
.
size
)
fig
=
plt
.
figure
(
frameon
=
False
)
fig
=
plt
.
figure
(
frameon
=
False
)
ax
=
fig
.
add_subplot
(
'111'
)
ax
=
fig
.
add_subplot
(
'111'
)
ax
.
plot
(
fraction_of_examples
,
cumulative_values
)
ax
.
plot
(
fraction_of_examples
,
cumulative_values
)
...
@@ -695,8 +685,9 @@ def add_cdf_image_summary(values, name):
...
@@ -695,8 +685,9 @@ def add_cdf_image_summary(values, name):
ax
.
set_xlabel
(
'fraction of examples'
)
ax
.
set_xlabel
(
'fraction of examples'
)
fig
.
canvas
.
draw
()
fig
.
canvas
.
draw
()
width
,
height
=
fig
.
get_size_inches
()
*
fig
.
get_dpi
()
width
,
height
=
fig
.
get_size_inches
()
*
fig
.
get_dpi
()
image
=
np
.
fromstring
(
fig
.
canvas
.
tostring_rgb
(),
dtype
=
'uint8'
).
reshape
(
image
=
np
.
fromstring
(
1
,
int
(
height
),
int
(
width
),
3
)
fig
.
canvas
.
tostring_rgb
(),
dtype
=
'uint8'
).
reshape
(
1
,
int
(
height
),
int
(
width
),
3
)
return
image
return
image
cdf_plot
=
tf
.
compat
.
v1
.
py_func
(
cdf_plot
,
[
values
],
tf
.
uint8
)
cdf_plot
=
tf
.
compat
.
v1
.
py_func
(
cdf_plot
,
[
values
],
tf
.
uint8
)
...
@@ -725,8 +716,8 @@ def add_hist_image_summary(values, bins, name):
...
@@ -725,8 +716,8 @@ def add_hist_image_summary(values, bins, name):
fig
.
canvas
.
draw
()
fig
.
canvas
.
draw
()
width
,
height
=
fig
.
get_size_inches
()
*
fig
.
get_dpi
()
width
,
height
=
fig
.
get_size_inches
()
*
fig
.
get_dpi
()
image
=
np
.
fromstring
(
image
=
np
.
fromstring
(
fig
.
canvas
.
tostring_rgb
(),
dtype
=
'uint8'
).
reshape
(
fig
.
canvas
.
tostring_rgb
(),
1
,
int
(
height
),
int
(
width
),
3
)
dtype
=
'uint8'
).
reshape
(
1
,
int
(
height
),
int
(
width
),
3
)
return
image
return
image
hist_plot
=
tf
.
compat
.
v1
.
py_func
(
hist_plot
,
[
values
,
bins
],
tf
.
uint8
)
hist_plot
=
tf
.
compat
.
v1
.
py_func
(
hist_plot
,
[
values
,
bins
],
tf
.
uint8
)
...
...
official/vision/image_classification/augment.py
View file @
88253ce5
...
@@ -24,6 +24,7 @@ from __future__ import division
...
@@ -24,6 +24,7 @@ from __future__ import division
from
__future__
import
print_function
from
__future__
import
print_function
import
math
import
math
import
tensorflow
as
tf
import
tensorflow
as
tf
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Text
,
Tuple
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Text
,
Tuple
...
@@ -120,10 +121,8 @@ def _convert_translation_to_transform(translations: tf.Tensor) -> tf.Tensor:
...
@@ -120,10 +121,8 @@ def _convert_translation_to_transform(translations: tf.Tensor) -> tf.Tensor:
)
)
def
_convert_angles_to_transform
(
def
_convert_angles_to_transform
(
angles
:
tf
.
Tensor
,
image_width
:
tf
.
Tensor
,
angles
:
tf
.
Tensor
,
image_height
:
tf
.
Tensor
)
->
tf
.
Tensor
:
image_width
:
tf
.
Tensor
,
image_height
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Converts an angle or angles to a projective transform.
"""Converts an angle or angles to a projective transform.
Args:
Args:
...
@@ -173,9 +172,7 @@ def transform(image: tf.Tensor, transforms) -> tf.Tensor:
...
@@ -173,9 +172,7 @@ def transform(image: tf.Tensor, transforms) -> tf.Tensor:
transforms
=
transforms
[
None
]
transforms
=
transforms
[
None
]
image
=
to_4d
(
image
)
image
=
to_4d
(
image
)
image
=
image_ops
.
transform
(
image
=
image_ops
.
transform
(
images
=
image
,
images
=
image
,
transforms
=
transforms
,
interpolation
=
'nearest'
)
transforms
=
transforms
,
interpolation
=
'nearest'
)
return
from_4d
(
image
,
original_ndims
)
return
from_4d
(
image
,
original_ndims
)
...
@@ -216,9 +213,8 @@ def rotate(image: tf.Tensor, degrees: float) -> tf.Tensor:
...
@@ -216,9 +213,8 @@ def rotate(image: tf.Tensor, degrees: float) -> tf.Tensor:
image_height
=
tf
.
cast
(
tf
.
shape
(
image
)[
1
],
tf
.
float32
)
image_height
=
tf
.
cast
(
tf
.
shape
(
image
)[
1
],
tf
.
float32
)
image_width
=
tf
.
cast
(
tf
.
shape
(
image
)[
2
],
tf
.
float32
)
image_width
=
tf
.
cast
(
tf
.
shape
(
image
)[
2
],
tf
.
float32
)
transforms
=
_convert_angles_to_transform
(
angles
=
radians
,
transforms
=
_convert_angles_to_transform
(
image_width
=
image_width
,
angles
=
radians
,
image_width
=
image_width
,
image_height
=
image_height
)
image_height
=
image_height
)
# In practice, we should randomize the rotation degrees by flipping
# In practice, we should randomize the rotation degrees by flipping
# it negatively half the time, but that's done on 'degrees' outside
# it negatively half the time, but that's done on 'degrees' outside
# of the function.
# of the function.
...
@@ -279,11 +275,10 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
...
@@ -279,11 +275,10 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
Args:
Args:
image: An image Tensor of type uint8.
image: An image Tensor of type uint8.
pad_size: Specifies how big the zero mask that will be generated is that
pad_size: Specifies how big the zero mask that will be generated is that is
is applied to the image. The mask will be of size
applied to the image. The mask will be of size (2*pad_size x 2*pad_size).
(2*pad_size x 2*pad_size).
replace: What pixel value to fill in the image in the area that has the
replace: What pixel value to fill in the image in the area that has
cutout mask applied to it.
the cutout mask applied to it.
Returns:
Returns:
An image Tensor that is of type uint8.
An image Tensor that is of type uint8.
...
@@ -293,30 +288,30 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
...
@@ -293,30 +288,30 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
# Sample the center location in the image where the zero mask will be applied.
# Sample the center location in the image where the zero mask will be applied.
cutout_center_height
=
tf
.
random
.
uniform
(
cutout_center_height
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0
,
maxval
=
image_height
,
shape
=
[],
minval
=
0
,
maxval
=
image_height
,
dtype
=
tf
.
int32
)
dtype
=
tf
.
int32
)
cutout_center_width
=
tf
.
random
.
uniform
(
cutout_center_width
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0
,
maxval
=
image_width
,
shape
=
[],
minval
=
0
,
maxval
=
image_width
,
dtype
=
tf
.
int32
)
dtype
=
tf
.
int32
)
lower_pad
=
tf
.
maximum
(
0
,
cutout_center_height
-
pad_size
)
lower_pad
=
tf
.
maximum
(
0
,
cutout_center_height
-
pad_size
)
upper_pad
=
tf
.
maximum
(
0
,
image_height
-
cutout_center_height
-
pad_size
)
upper_pad
=
tf
.
maximum
(
0
,
image_height
-
cutout_center_height
-
pad_size
)
left_pad
=
tf
.
maximum
(
0
,
cutout_center_width
-
pad_size
)
left_pad
=
tf
.
maximum
(
0
,
cutout_center_width
-
pad_size
)
right_pad
=
tf
.
maximum
(
0
,
image_width
-
cutout_center_width
-
pad_size
)
right_pad
=
tf
.
maximum
(
0
,
image_width
-
cutout_center_width
-
pad_size
)
cutout_shape
=
[
image_height
-
(
lower_pad
+
upper_pad
),
cutout_shape
=
[
image_width
-
(
left_pad
+
right_pad
)]
image_height
-
(
lower_pad
+
upper_pad
),
image_width
-
(
left_pad
+
right_pad
)
]
padding_dims
=
[[
lower_pad
,
upper_pad
],
[
left_pad
,
right_pad
]]
padding_dims
=
[[
lower_pad
,
upper_pad
],
[
left_pad
,
right_pad
]]
mask
=
tf
.
pad
(
mask
=
tf
.
pad
(
tf
.
zeros
(
cutout_shape
,
dtype
=
image
.
dtype
),
tf
.
zeros
(
cutout_shape
,
dtype
=
image
.
dtype
),
padding_dims
,
constant_values
=
1
)
padding_dims
,
constant_values
=
1
)
mask
=
tf
.
expand_dims
(
mask
,
-
1
)
mask
=
tf
.
expand_dims
(
mask
,
-
1
)
mask
=
tf
.
tile
(
mask
,
[
1
,
1
,
3
])
mask
=
tf
.
tile
(
mask
,
[
1
,
1
,
3
])
image
=
tf
.
where
(
image
=
tf
.
where
(
tf
.
equal
(
mask
,
0
),
tf
.
equal
(
mask
,
0
),
tf
.
ones_like
(
image
,
dtype
=
image
.
dtype
)
*
replace
,
tf
.
ones_like
(
image
,
dtype
=
image
.
dtype
)
*
replace
,
image
)
image
)
return
image
return
image
...
@@ -398,8 +393,8 @@ def shear_x(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
...
@@ -398,8 +393,8 @@ def shear_x(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
# with a matrix form of:
# with a matrix form of:
# [1 level
# [1 level
# 0 1].
# 0 1].
image
=
transform
(
image
=
wrap
(
image
),
image
=
transform
(
transforms
=
[
1.
,
level
,
0.
,
0.
,
1.
,
0.
,
0.
,
0.
])
image
=
wrap
(
image
),
transforms
=
[
1.
,
level
,
0.
,
0.
,
1.
,
0.
,
0.
,
0.
])
return
unwrap
(
image
,
replace
)
return
unwrap
(
image
,
replace
)
...
@@ -409,8 +404,8 @@ def shear_y(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
...
@@ -409,8 +404,8 @@ def shear_y(image: tf.Tensor, level: float, replace: int) -> tf.Tensor:
# with a matrix form of:
# with a matrix form of:
# [1 0
# [1 0
# level 1].
# level 1].
image
=
transform
(
image
=
wrap
(
image
),
image
=
transform
(
transforms
=
[
1.
,
0.
,
0.
,
level
,
1.
,
0.
,
0.
,
0.
])
image
=
wrap
(
image
),
transforms
=
[
1.
,
0.
,
0.
,
level
,
1.
,
0.
,
0.
,
0.
])
return
unwrap
(
image
,
replace
)
return
unwrap
(
image
,
replace
)
...
@@ -460,9 +455,9 @@ def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor:
...
@@ -460,9 +455,9 @@ def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor:
# Make image 4D for conv operation.
# Make image 4D for conv operation.
image
=
tf
.
expand_dims
(
image
,
0
)
image
=
tf
.
expand_dims
(
image
,
0
)
# SMOOTH PIL Kernel.
# SMOOTH PIL Kernel.
kernel
=
tf
.
constant
(
kernel
=
tf
.
constant
(
[[
1
,
1
,
1
],
[
1
,
5
,
1
],
[
1
,
1
,
1
]],
[[
1
,
1
,
1
],
[
1
,
5
,
1
],
[
1
,
1
,
1
]],
dtype
=
tf
.
float32
,
dtype
=
tf
.
float32
,
shape
=
[
3
,
3
,
1
,
1
])
/
13.
shape
=
[
3
,
3
,
1
,
1
])
/
13.
# Tile across channel dimension.
# Tile across channel dimension.
kernel
=
tf
.
tile
(
kernel
,
[
1
,
1
,
3
,
1
])
kernel
=
tf
.
tile
(
kernel
,
[
1
,
1
,
3
,
1
])
strides
=
[
1
,
1
,
1
,
1
]
strides
=
[
1
,
1
,
1
,
1
]
...
@@ -484,6 +479,7 @@ def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor:
...
@@ -484,6 +479,7 @@ def sharpness(image: tf.Tensor, factor: float) -> tf.Tensor:
def
equalize
(
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
def
equalize
(
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Implements Equalize function from PIL using TF ops."""
"""Implements Equalize function from PIL using TF ops."""
def
scale_channel
(
im
,
c
):
def
scale_channel
(
im
,
c
):
"""Scale the data in the channel to implement equalize."""
"""Scale the data in the channel to implement equalize."""
im
=
tf
.
cast
(
im
[:,
:,
c
],
tf
.
int32
)
im
=
tf
.
cast
(
im
[:,
:,
c
],
tf
.
int32
)
...
@@ -507,9 +503,9 @@ def equalize(image: tf.Tensor) -> tf.Tensor:
...
@@ -507,9 +503,9 @@ def equalize(image: tf.Tensor) -> tf.Tensor:
# If step is zero, return the original image. Otherwise, build
# If step is zero, return the original image. Otherwise, build
# lut from the full histogram and step and then index from it.
# lut from the full histogram and step and then index from it.
result
=
tf
.
cond
(
tf
.
equal
(
step
,
0
),
result
=
tf
.
cond
(
lambda
:
im
,
tf
.
equal
(
step
,
0
),
lambda
:
im
,
lambda
:
tf
.
gather
(
build_lut
(
histo
,
step
),
im
))
lambda
:
tf
.
gather
(
build_lut
(
histo
,
step
),
im
))
return
tf
.
cast
(
result
,
tf
.
uint8
)
return
tf
.
cast
(
result
,
tf
.
uint8
)
...
@@ -582,7 +578,7 @@ def _randomly_negate_tensor(tensor):
...
@@ -582,7 +578,7 @@ def _randomly_negate_tensor(tensor):
def
_rotate_level_to_arg
(
level
:
float
):
def
_rotate_level_to_arg
(
level
:
float
):
level
=
(
level
/
_MAX_LEVEL
)
*
30.
level
=
(
level
/
_MAX_LEVEL
)
*
30.
level
=
_randomly_negate_tensor
(
level
)
level
=
_randomly_negate_tensor
(
level
)
return
(
level
,)
return
(
level
,)
...
@@ -597,18 +593,18 @@ def _shrink_level_to_arg(level: float):
...
@@ -597,18 +593,18 @@ def _shrink_level_to_arg(level: float):
def
_enhance_level_to_arg
(
level
:
float
):
def
_enhance_level_to_arg
(
level
:
float
):
return
((
level
/
_MAX_LEVEL
)
*
1.8
+
0.1
,)
return
((
level
/
_MAX_LEVEL
)
*
1.8
+
0.1
,)
def
_shear_level_to_arg
(
level
:
float
):
def
_shear_level_to_arg
(
level
:
float
):
level
=
(
level
/
_MAX_LEVEL
)
*
0.3
level
=
(
level
/
_MAX_LEVEL
)
*
0.3
# Flip level to negative with 50% chance.
# Flip level to negative with 50% chance.
level
=
_randomly_negate_tensor
(
level
)
level
=
_randomly_negate_tensor
(
level
)
return
(
level
,)
return
(
level
,)
def
_translate_level_to_arg
(
level
:
float
,
translate_const
:
float
):
def
_translate_level_to_arg
(
level
:
float
,
translate_const
:
float
):
level
=
(
level
/
_MAX_LEVEL
)
*
float
(
translate_const
)
level
=
(
level
/
_MAX_LEVEL
)
*
float
(
translate_const
)
# Flip level to negative with 50% chance.
# Flip level to negative with 50% chance.
level
=
_randomly_negate_tensor
(
level
)
level
=
_randomly_negate_tensor
(
level
)
return
(
level
,)
return
(
level
,)
...
@@ -618,20 +614,15 @@ def _mult_to_arg(level: float, multiplier: float = 1.):
...
@@ -618,20 +614,15 @@ def _mult_to_arg(level: float, multiplier: float = 1.):
return
(
int
((
level
/
_MAX_LEVEL
)
*
multiplier
),)
return
(
int
((
level
/
_MAX_LEVEL
)
*
multiplier
),)
def
_apply_func_with_prob
(
func
:
Any
,
def
_apply_func_with_prob
(
func
:
Any
,
image
:
tf
.
Tensor
,
args
:
Any
,
prob
:
float
):
image
:
tf
.
Tensor
,
args
:
Any
,
prob
:
float
):
"""Apply `func` to image w/ `args` as input with probability `prob`."""
"""Apply `func` to image w/ `args` as input with probability `prob`."""
assert
isinstance
(
args
,
tuple
)
assert
isinstance
(
args
,
tuple
)
# Apply the function with probability `prob`.
# Apply the function with probability `prob`.
should_apply_op
=
tf
.
cast
(
should_apply_op
=
tf
.
cast
(
tf
.
floor
(
tf
.
random
.
uniform
([],
dtype
=
tf
.
float32
)
+
prob
),
tf
.
bool
)
tf
.
floor
(
tf
.
random
.
uniform
([],
dtype
=
tf
.
float32
)
+
prob
),
tf
.
bool
)
augmented_image
=
tf
.
cond
(
augmented_image
=
tf
.
cond
(
should_apply_op
,
lambda
:
func
(
image
,
*
args
),
should_apply_op
,
lambda
:
image
)
lambda
:
func
(
image
,
*
args
),
lambda
:
image
)
return
augmented_image
return
augmented_image
...
@@ -709,11 +700,8 @@ def level_to_arg(cutout_const: float, translate_const: float):
...
@@ -709,11 +700,8 @@ def level_to_arg(cutout_const: float, translate_const: float):
return
args
return
args
def
_parse_policy_info
(
name
:
Text
,
def
_parse_policy_info
(
name
:
Text
,
prob
:
float
,
level
:
float
,
prob
:
float
,
replace_value
:
List
[
int
],
cutout_const
:
float
,
level
:
float
,
replace_value
:
List
[
int
],
cutout_const
:
float
,
translate_const
:
float
)
->
Tuple
[
Any
,
float
,
Any
]:
translate_const
:
float
)
->
Tuple
[
Any
,
float
,
Any
]:
"""Return the function that corresponds to `name` and update `level` param."""
"""Return the function that corresponds to `name` and update `level` param."""
func
=
NAME_TO_FUNC
[
name
]
func
=
NAME_TO_FUNC
[
name
]
...
@@ -969,8 +957,9 @@ class RandAugment(ImageAugment):
...
@@ -969,8 +957,9 @@ class RandAugment(ImageAugment):
min_prob
,
max_prob
=
0.2
,
0.8
min_prob
,
max_prob
=
0.2
,
0.8
for
_
in
range
(
self
.
num_layers
):
for
_
in
range
(
self
.
num_layers
):
op_to_select
=
tf
.
random
.
uniform
(
op_to_select
=
tf
.
random
.
uniform
([],
[],
maxval
=
len
(
self
.
available_ops
)
+
1
,
dtype
=
tf
.
int32
)
maxval
=
len
(
self
.
available_ops
)
+
1
,
dtype
=
tf
.
int32
)
branch_fns
=
[]
branch_fns
=
[]
for
(
i
,
op_name
)
in
enumerate
(
self
.
available_ops
):
for
(
i
,
op_name
)
in
enumerate
(
self
.
available_ops
):
...
@@ -978,11 +967,8 @@ class RandAugment(ImageAugment):
...
@@ -978,11 +967,8 @@ class RandAugment(ImageAugment):
minval
=
min_prob
,
minval
=
min_prob
,
maxval
=
max_prob
,
maxval
=
max_prob
,
dtype
=
tf
.
float32
)
dtype
=
tf
.
float32
)
func
,
_
,
args
=
_parse_policy_info
(
op_name
,
func
,
_
,
args
=
_parse_policy_info
(
op_name
,
prob
,
self
.
magnitude
,
prob
,
replace_value
,
self
.
cutout_const
,
self
.
magnitude
,
replace_value
,
self
.
cutout_const
,
self
.
translate_const
)
self
.
translate_const
)
branch_fns
.
append
((
branch_fns
.
append
((
i
,
i
,
...
@@ -991,9 +977,10 @@ class RandAugment(ImageAugment):
...
@@ -991,9 +977,10 @@ class RandAugment(ImageAugment):
image
,
*
selected_args
)))
image
,
*
selected_args
)))
# pylint:enable=g-long-lambda
# pylint:enable=g-long-lambda
image
=
tf
.
switch_case
(
branch_index
=
op_to_select
,
image
=
tf
.
switch_case
(
branch_fns
=
branch_fns
,
branch_index
=
op_to_select
,
default
=
lambda
:
tf
.
identity
(
image
))
branch_fns
=
branch_fns
,
default
=
lambda
:
tf
.
identity
(
image
))
image
=
tf
.
cast
(
image
,
dtype
=
input_image_type
)
image
=
tf
.
cast
(
image
,
dtype
=
input_image_type
)
return
image
return
image
official/vision/image_classification/augment_test.py
View file @
88253ce5
...
@@ -49,24 +49,15 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -49,24 +49,15 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
def
test_transform
(
self
,
dtype
):
def
test_transform
(
self
,
dtype
):
image
=
tf
.
constant
([[
1
,
2
],
[
3
,
4
]],
dtype
=
dtype
)
image
=
tf
.
constant
([[
1
,
2
],
[
3
,
4
]],
dtype
=
dtype
)
self
.
assertAllEqual
(
augment
.
transform
(
image
,
transforms
=
[
1
]
*
8
),
self
.
assertAllEqual
(
[[
4
,
4
],
[
4
,
4
]])
augment
.
transform
(
image
,
transforms
=
[
1
]
*
8
),
[[
4
,
4
],
[
4
,
4
]])
def
test_translate
(
self
,
dtype
):
def
test_translate
(
self
,
dtype
):
image
=
tf
.
constant
(
image
=
tf
.
constant
(
[[
1
,
0
,
1
,
0
],
[[
1
,
0
,
1
,
0
],
[
0
,
1
,
0
,
1
],
[
1
,
0
,
1
,
0
],
[
0
,
1
,
0
,
1
]],
dtype
=
dtype
)
[
0
,
1
,
0
,
1
],
[
1
,
0
,
1
,
0
],
[
0
,
1
,
0
,
1
]],
dtype
=
dtype
)
translations
=
[
-
1
,
-
1
]
translations
=
[
-
1
,
-
1
]
translated
=
augment
.
translate
(
image
=
image
,
translated
=
augment
.
translate
(
image
=
image
,
translations
=
translations
)
translations
=
translations
)
expected
=
[[
1
,
0
,
1
,
1
],
[
0
,
1
,
0
,
0
],
[
1
,
0
,
1
,
1
],
[
1
,
0
,
1
,
1
]]
expected
=
[
[
1
,
0
,
1
,
1
],
[
0
,
1
,
0
,
0
],
[
1
,
0
,
1
,
1
],
[
1
,
0
,
1
,
1
]]
self
.
assertAllEqual
(
translated
,
expected
)
self
.
assertAllEqual
(
translated
,
expected
)
def
test_translate_shapes
(
self
,
dtype
):
def
test_translate_shapes
(
self
,
dtype
):
...
@@ -85,9 +76,7 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -85,9 +76,7 @@ class TransformsTest(parameterized.TestCase, tf.test.TestCase):
image
=
tf
.
reshape
(
tf
.
cast
(
tf
.
range
(
9
),
dtype
),
(
3
,
3
))
image
=
tf
.
reshape
(
tf
.
cast
(
tf
.
range
(
9
),
dtype
),
(
3
,
3
))
rotation
=
90.
rotation
=
90.
transformed
=
augment
.
rotate
(
image
=
image
,
degrees
=
rotation
)
transformed
=
augment
.
rotate
(
image
=
image
,
degrees
=
rotation
)
expected
=
[[
2
,
5
,
8
],
expected
=
[[
2
,
5
,
8
],
[
1
,
4
,
7
],
[
0
,
3
,
6
]]
[
1
,
4
,
7
],
[
0
,
3
,
6
]]
self
.
assertAllEqual
(
transformed
,
expected
)
self
.
assertAllEqual
(
transformed
,
expected
)
def
test_rotate_shapes
(
self
,
dtype
):
def
test_rotate_shapes
(
self
,
dtype
):
...
@@ -129,15 +118,13 @@ class AutoaugmentTest(tf.test.TestCase):
...
@@ -129,15 +118,13 @@ class AutoaugmentTest(tf.test.TestCase):
image
=
tf
.
ones
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
image
=
tf
.
ones
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
for
op_name
in
augment
.
NAME_TO_FUNC
:
for
op_name
in
augment
.
NAME_TO_FUNC
:
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
prob
,
replace_value
,
cutout_const
,
magnitude
,
replace_value
,
cutout_const
,
translate_const
)
translate_const
)
image
=
func
(
image
,
*
args
)
image
=
func
(
image
,
*
args
)
self
.
assertEqual
((
224
,
224
,
3
),
image
.
shape
)
self
.
assertEqual
((
224
,
224
,
3
),
image
.
shape
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
official/vision/image_classification/callbacks.py
View file @
88253ce5
...
@@ -21,6 +21,7 @@ from __future__ import print_function
...
@@ -21,6 +21,7 @@ from __future__ import print_function
import
os
import
os
from
typing
import
Any
,
List
,
MutableMapping
,
Text
from
typing
import
Any
,
List
,
MutableMapping
,
Text
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -43,8 +44,9 @@ def get_callbacks(model_checkpoint: bool = True,
...
@@ -43,8 +44,9 @@ def get_callbacks(model_checkpoint: bool = True,
callbacks
=
[]
callbacks
=
[]
if
model_checkpoint
:
if
model_checkpoint
:
ckpt_full_path
=
os
.
path
.
join
(
model_dir
,
'model.ckpt-{epoch:04d}'
)
ckpt_full_path
=
os
.
path
.
join
(
model_dir
,
'model.ckpt-{epoch:04d}'
)
callbacks
.
append
(
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
callbacks
.
append
(
ckpt_full_path
,
save_weights_only
=
True
,
verbose
=
1
))
tf
.
keras
.
callbacks
.
ModelCheckpoint
(
ckpt_full_path
,
save_weights_only
=
True
,
verbose
=
1
))
if
include_tensorboard
:
if
include_tensorboard
:
callbacks
.
append
(
callbacks
.
append
(
CustomTensorBoard
(
CustomTensorBoard
(
...
@@ -61,13 +63,14 @@ def get_callbacks(model_checkpoint: bool = True,
...
@@ -61,13 +63,14 @@ def get_callbacks(model_checkpoint: bool = True,
if
apply_moving_average
:
if
apply_moving_average
:
# Save moving average model to a different file so that
# Save moving average model to a different file so that
# we can resume training from a checkpoint
# we can resume training from a checkpoint
ckpt_full_path
=
os
.
path
.
join
(
ckpt_full_path
=
os
.
path
.
join
(
model_dir
,
'average'
,
model_dir
,
'average'
,
'model.ckpt-{epoch:04d}'
)
'model.ckpt-{epoch:04d}'
)
callbacks
.
append
(
AverageModelCheckpoint
(
callbacks
.
append
(
update_weights
=
False
,
AverageModelCheckpoint
(
filepath
=
ckpt_full_path
,
update_weights
=
False
,
save_weights_only
=
True
,
filepath
=
ckpt_full_path
,
verbose
=
1
))
save_weights_only
=
True
,
verbose
=
1
))
callbacks
.
append
(
MovingAverageCallback
())
callbacks
.
append
(
MovingAverageCallback
())
return
callbacks
return
callbacks
...
@@ -175,16 +178,13 @@ class MovingAverageCallback(tf.keras.callbacks.Callback):
...
@@ -175,16 +178,13 @@ class MovingAverageCallback(tf.keras.callbacks.Callback):
**kwargs: Any additional callback arguments.
**kwargs: Any additional callback arguments.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
overwrite_weights_on_train_end
:
bool
=
False
,
**
kwargs
):
overwrite_weights_on_train_end
:
bool
=
False
,
**
kwargs
):
super
(
MovingAverageCallback
,
self
).
__init__
(
**
kwargs
)
super
(
MovingAverageCallback
,
self
).
__init__
(
**
kwargs
)
self
.
overwrite_weights_on_train_end
=
overwrite_weights_on_train_end
self
.
overwrite_weights_on_train_end
=
overwrite_weights_on_train_end
def
set_model
(
self
,
model
:
tf
.
keras
.
Model
):
def
set_model
(
self
,
model
:
tf
.
keras
.
Model
):
super
(
MovingAverageCallback
,
self
).
set_model
(
model
)
super
(
MovingAverageCallback
,
self
).
set_model
(
model
)
assert
isinstance
(
self
.
model
.
optimizer
,
assert
isinstance
(
self
.
model
.
optimizer
,
optimizer_factory
.
MovingAverage
)
optimizer_factory
.
MovingAverage
)
self
.
model
.
optimizer
.
shadow_copy
(
self
.
model
)
self
.
model
.
optimizer
.
shadow_copy
(
self
.
model
)
def
on_test_begin
(
self
,
logs
:
MutableMapping
[
Text
,
Any
]
=
None
):
def
on_test_begin
(
self
,
logs
:
MutableMapping
[
Text
,
Any
]
=
None
):
...
@@ -204,40 +204,30 @@ class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
...
@@ -204,40 +204,30 @@ class AverageModelCheckpoint(tf.keras.callbacks.ModelCheckpoint):
Taken from tfa.callbacks.AverageModelCheckpoint.
Taken from tfa.callbacks.AverageModelCheckpoint.
Attributes:
Attributes:
update_weights: If True, assign the moving average weights
update_weights: If True, assign the moving average weights to the model, and
to the model, and save them. If False, keep the old
save them. If False, keep the old non-averaged weights, but the saved
non-averaged weights, but the saved model uses the
model uses the average weights. See `tf.keras.callbacks.ModelCheckpoint`
average weights.
for the other args.
See `tf.keras.callbacks.ModelCheckpoint` for the other args.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
update_weights
:
bool
,
update_weights
:
bool
,
filepath
:
str
,
filepath
:
str
,
monitor
:
str
=
'val_loss'
,
monitor
:
str
=
'val_loss'
,
verbose
:
int
=
0
,
verbose
:
int
=
0
,
save_best_only
:
bool
=
False
,
save_best_only
:
bool
=
False
,
save_weights_only
:
bool
=
False
,
save_weights_only
:
bool
=
False
,
mode
:
str
=
'auto'
,
mode
:
str
=
'auto'
,
save_freq
:
str
=
'epoch'
,
save_freq
:
str
=
'epoch'
,
**
kwargs
):
**
kwargs
):
self
.
update_weights
=
update_weights
self
.
update_weights
=
update_weights
super
().
__init__
(
super
().
__init__
(
filepath
,
monitor
,
verbose
,
save_best_only
,
filepath
,
save_weights_only
,
mode
,
save_freq
,
**
kwargs
)
monitor
,
verbose
,
save_best_only
,
save_weights_only
,
mode
,
save_freq
,
**
kwargs
)
def
set_model
(
self
,
model
):
def
set_model
(
self
,
model
):
if
not
isinstance
(
model
.
optimizer
,
optimizer_factory
.
MovingAverage
):
if
not
isinstance
(
model
.
optimizer
,
optimizer_factory
.
MovingAverage
):
raise
TypeError
(
raise
TypeError
(
'AverageModelCheckpoint is only used when training'
'AverageModelCheckpoint is only used when training'
'with MovingAverage'
)
'with MovingAverage'
)
return
super
().
set_model
(
model
)
return
super
().
set_model
(
model
)
def
_save_model
(
self
,
epoch
,
logs
):
def
_save_model
(
self
,
epoch
,
logs
):
...
...
official/vision/image_classification/classifier_trainer.py
View file @
88253ce5
...
@@ -41,7 +41,7 @@ from official.vision.image_classification.resnet import resnet_model
...
@@ -41,7 +41,7 @@ from official.vision.image_classification.resnet import resnet_model
def
get_models
()
->
Mapping
[
str
,
tf
.
keras
.
Model
]:
def
get_models
()
->
Mapping
[
str
,
tf
.
keras
.
Model
]:
"""Returns the mapping from model type name to Keras model."""
"""Returns the mapping from model type name to Keras model."""
return
{
return
{
'efficientnet'
:
efficientnet_model
.
EfficientNet
.
from_name
,
'efficientnet'
:
efficientnet_model
.
EfficientNet
.
from_name
,
'resnet'
:
resnet_model
.
resnet50
,
'resnet'
:
resnet_model
.
resnet50
,
}
}
...
@@ -55,7 +55,7 @@ def get_dtype_map() -> Mapping[str, tf.dtypes.DType]:
...
@@ -55,7 +55,7 @@ def get_dtype_map() -> Mapping[str, tf.dtypes.DType]:
'float16'
:
tf
.
float16
,
'float16'
:
tf
.
float16
,
'fp32'
:
tf
.
float32
,
'fp32'
:
tf
.
float32
,
'bf16'
:
tf
.
bfloat16
,
'bf16'
:
tf
.
bfloat16
,
}
}
def
_get_metrics
(
one_hot
:
bool
)
->
Mapping
[
Text
,
Any
]:
def
_get_metrics
(
one_hot
:
bool
)
->
Mapping
[
Text
,
Any
]:
...
@@ -63,22 +63,28 @@ def _get_metrics(one_hot: bool) -> Mapping[Text, Any]:
...
@@ -63,22 +63,28 @@ def _get_metrics(one_hot: bool) -> Mapping[Text, Any]:
if
one_hot
:
if
one_hot
:
return
{
return
{
# (name, metric_fn)
# (name, metric_fn)
'acc'
:
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
'acc'
:
'accuracy'
:
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
'top_1'
:
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
'accuracy'
:
'top_5'
:
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
k
=
5
,
'top_1'
:
name
=
'top_5_accuracy'
),
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
'top_5'
:
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
k
=
5
,
name
=
'top_5_accuracy'
),
}
}
else
:
else
:
return
{
return
{
# (name, metric_fn)
# (name, metric_fn)
'acc'
:
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
'acc'
:
'accuracy'
:
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
'top_1'
:
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
'accuracy'
:
'top_5'
:
tf
.
keras
.
metrics
.
SparseTopKCategoricalAccuracy
(
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
k
=
5
,
'top_1'
:
name
=
'top_5_accuracy'
),
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
name
=
'accuracy'
),
'top_5'
:
tf
.
keras
.
metrics
.
SparseTopKCategoricalAccuracy
(
k
=
5
,
name
=
'top_5_accuracy'
),
}
}
...
@@ -94,8 +100,7 @@ def get_image_size_from_model(
...
@@ -94,8 +100,7 @@ def get_image_size_from_model(
def
_get_dataset_builders
(
params
:
base_configs
.
ExperimentConfig
,
def
_get_dataset_builders
(
params
:
base_configs
.
ExperimentConfig
,
strategy
:
tf
.
distribute
.
Strategy
,
strategy
:
tf
.
distribute
.
Strategy
,
one_hot
:
bool
one_hot
:
bool
)
->
Tuple
[
Any
,
Any
]:
)
->
Tuple
[
Any
,
Any
]:
"""Create and return train and validation dataset builders."""
"""Create and return train and validation dataset builders."""
if
one_hot
:
if
one_hot
:
logging
.
warning
(
'label_smoothing > 0, so datasets will be one hot encoded.'
)
logging
.
warning
(
'label_smoothing > 0, so datasets will be one hot encoded.'
)
...
@@ -107,9 +112,7 @@ def _get_dataset_builders(params: base_configs.ExperimentConfig,
...
@@ -107,9 +112,7 @@ def _get_dataset_builders(params: base_configs.ExperimentConfig,
image_size
=
get_image_size_from_model
(
params
)
image_size
=
get_image_size_from_model
(
params
)
dataset_configs
=
[
dataset_configs
=
[
params
.
train_dataset
,
params
.
validation_dataset
]
params
.
train_dataset
,
params
.
validation_dataset
]
builders
=
[]
builders
=
[]
for
config
in
dataset_configs
:
for
config
in
dataset_configs
:
...
@@ -171,8 +174,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
...
@@ -171,8 +174,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
},
},
}
}
overriding_configs
=
(
flags_obj
.
config_file
,
overriding_configs
=
(
flags_obj
.
config_file
,
flags_obj
.
params_override
,
flags_obj
.
params_override
,
flags_overrides
)
flags_overrides
)
pp
=
pprint
.
PrettyPrinter
()
pp
=
pprint
.
PrettyPrinter
()
...
@@ -190,8 +192,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
...
@@ -190,8 +192,7 @@ def _get_params_from_flags(flags_obj: flags.FlagValues):
return
params
return
params
def
resume_from_checkpoint
(
model
:
tf
.
keras
.
Model
,
def
resume_from_checkpoint
(
model
:
tf
.
keras
.
Model
,
model_dir
:
str
,
model_dir
:
str
,
train_steps
:
int
)
->
int
:
train_steps
:
int
)
->
int
:
"""Resumes from the latest checkpoint, if possible.
"""Resumes from the latest checkpoint, if possible.
...
@@ -226,8 +227,7 @@ def resume_from_checkpoint(model: tf.keras.Model,
...
@@ -226,8 +227,7 @@ def resume_from_checkpoint(model: tf.keras.Model,
def
initialize
(
params
:
base_configs
.
ExperimentConfig
,
def
initialize
(
params
:
base_configs
.
ExperimentConfig
,
dataset_builder
:
dataset_factory
.
DatasetBuilder
):
dataset_builder
:
dataset_factory
.
DatasetBuilder
):
"""Initializes backend related initializations."""
"""Initializes backend related initializations."""
keras_utils
.
set_session_config
(
keras_utils
.
set_session_config
(
enable_xla
=
params
.
runtime
.
enable_xla
)
enable_xla
=
params
.
runtime
.
enable_xla
)
performance
.
set_mixed_precision_policy
(
dataset_builder
.
dtype
,
performance
.
set_mixed_precision_policy
(
dataset_builder
.
dtype
,
get_loss_scale
(
params
))
get_loss_scale
(
params
))
if
tf
.
config
.
list_physical_devices
(
'GPU'
):
if
tf
.
config
.
list_physical_devices
(
'GPU'
):
...
@@ -244,7 +244,8 @@ def initialize(params: base_configs.ExperimentConfig,
...
@@ -244,7 +244,8 @@ def initialize(params: base_configs.ExperimentConfig,
per_gpu_thread_count
=
params
.
runtime
.
per_gpu_thread_count
,
per_gpu_thread_count
=
params
.
runtime
.
per_gpu_thread_count
,
gpu_thread_mode
=
params
.
runtime
.
gpu_thread_mode
,
gpu_thread_mode
=
params
.
runtime
.
gpu_thread_mode
,
num_gpus
=
params
.
runtime
.
num_gpus
,
num_gpus
=
params
.
runtime
.
num_gpus
,
datasets_num_private_threads
=
params
.
runtime
.
dataset_num_private_threads
)
# pylint:disable=line-too-long
datasets_num_private_threads
=
params
.
runtime
.
dataset_num_private_threads
)
# pylint:disable=line-too-long
if
params
.
runtime
.
batchnorm_spatial_persistent
:
if
params
.
runtime
.
batchnorm_spatial_persistent
:
os
.
environ
[
'TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'
]
=
'1'
os
.
environ
[
'TF_USE_CUDNN_BATCHNORM_SPATIAL_PERSISTENT'
]
=
'1'
...
@@ -253,9 +254,7 @@ def define_classifier_flags():
...
@@ -253,9 +254,7 @@ def define_classifier_flags():
"""Defines common flags for image classification."""
"""Defines common flags for image classification."""
hyperparams_flags
.
initialize_common_flags
()
hyperparams_flags
.
initialize_common_flags
()
flags
.
DEFINE_string
(
flags
.
DEFINE_string
(
'data_dir'
,
'data_dir'
,
default
=
None
,
help
=
'The location of the input data.'
)
default
=
None
,
help
=
'The location of the input data.'
)
flags
.
DEFINE_string
(
flags
.
DEFINE_string
(
'mode'
,
'mode'
,
default
=
None
,
default
=
None
,
...
@@ -278,8 +277,7 @@ def define_classifier_flags():
...
@@ -278,8 +277,7 @@ def define_classifier_flags():
help
=
'The interval of steps between logging of batch level stats.'
)
help
=
'The interval of steps between logging of batch level stats.'
)
def
serialize_config
(
params
:
base_configs
.
ExperimentConfig
,
def
serialize_config
(
params
:
base_configs
.
ExperimentConfig
,
model_dir
:
str
):
model_dir
:
str
):
"""Serializes and saves the experiment config."""
"""Serializes and saves the experiment config."""
params_save_path
=
os
.
path
.
join
(
model_dir
,
'params.yaml'
)
params_save_path
=
os
.
path
.
join
(
model_dir
,
'params.yaml'
)
logging
.
info
(
'Saving experiment configuration to %s'
,
params_save_path
)
logging
.
info
(
'Saving experiment configuration to %s'
,
params_save_path
)
...
@@ -293,9 +291,8 @@ def train_and_eval(
...
@@ -293,9 +291,8 @@ def train_and_eval(
"""Runs the train and eval path using compile/fit."""
"""Runs the train and eval path using compile/fit."""
logging
.
info
(
'Running train and eval.'
)
logging
.
info
(
'Running train and eval.'
)
distribution_utils
.
configure_cluster
(
distribution_utils
.
configure_cluster
(
params
.
runtime
.
worker_hosts
,
params
.
runtime
.
worker_hosts
,
params
.
runtime
.
task_index
)
params
.
runtime
.
task_index
)
# Note: for TPUs, strategy and scope should be created before the dataset
# Note: for TPUs, strategy and scope should be created before the dataset
strategy
=
strategy_override
or
distribution_utils
.
get_distribution_strategy
(
strategy
=
strategy_override
or
distribution_utils
.
get_distribution_strategy
(
...
@@ -313,8 +310,9 @@ def train_and_eval(
...
@@ -313,8 +310,9 @@ def train_and_eval(
one_hot
=
label_smoothing
and
label_smoothing
>
0
one_hot
=
label_smoothing
and
label_smoothing
>
0
builders
=
_get_dataset_builders
(
params
,
strategy
,
one_hot
)
builders
=
_get_dataset_builders
(
params
,
strategy
,
one_hot
)
datasets
=
[
builder
.
build
(
strategy
)
datasets
=
[
if
builder
else
None
for
builder
in
builders
]
builder
.
build
(
strategy
)
if
builder
else
None
for
builder
in
builders
]
# Unpack datasets and builders based on train/val/test splits
# Unpack datasets and builders based on train/val/test splits
train_builder
,
validation_builder
=
builders
# pylint: disable=unbalanced-tuple-unpacking
train_builder
,
validation_builder
=
builders
# pylint: disable=unbalanced-tuple-unpacking
...
@@ -351,16 +349,16 @@ def train_and_eval(
...
@@ -351,16 +349,16 @@ def train_and_eval(
label_smoothing
=
params
.
model
.
loss
.
label_smoothing
)
label_smoothing
=
params
.
model
.
loss
.
label_smoothing
)
else
:
else
:
loss_obj
=
tf
.
keras
.
losses
.
SparseCategoricalCrossentropy
()
loss_obj
=
tf
.
keras
.
losses
.
SparseCategoricalCrossentropy
()
model
.
compile
(
optimizer
=
optimizer
,
model
.
compile
(
loss
=
loss_obj
,
optimizer
=
optimizer
,
metrics
=
metrics
,
loss
=
loss_obj
,
experimental_steps_per_execution
=
steps_per_loop
)
metrics
=
metrics
,
experimental_steps_per_execution
=
steps_per_loop
)
initial_epoch
=
0
initial_epoch
=
0
if
params
.
train
.
resume_checkpoint
:
if
params
.
train
.
resume_checkpoint
:
initial_epoch
=
resume_from_checkpoint
(
model
=
model
,
initial_epoch
=
resume_from_checkpoint
(
model_dir
=
params
.
model_dir
,
model
=
model
,
model_dir
=
params
.
model_dir
,
train_steps
=
train_steps
)
train_steps
=
train_steps
)
callbacks
=
custom_callbacks
.
get_callbacks
(
callbacks
=
custom_callbacks
.
get_callbacks
(
model_checkpoint
=
params
.
train
.
callbacks
.
enable_checkpoint_and_export
,
model_checkpoint
=
params
.
train
.
callbacks
.
enable_checkpoint_and_export
,
...
@@ -399,9 +397,7 @@ def train_and_eval(
...
@@ -399,9 +397,7 @@ def train_and_eval(
validation_dataset
,
steps
=
validation_steps
,
verbose
=
2
)
validation_dataset
,
steps
=
validation_steps
,
verbose
=
2
)
# TODO(dankondratyuk): eval and save final test accuracy
# TODO(dankondratyuk): eval and save final test accuracy
stats
=
common
.
build_stats
(
history
,
stats
=
common
.
build_stats
(
history
,
validation_output
,
callbacks
)
validation_output
,
callbacks
)
return
stats
return
stats
...
...
official/vision/image_classification/classifier_trainer_test.py
View file @
88253ce5
...
@@ -105,14 +105,13 @@ def get_trivial_model(num_classes: int) -> tf.keras.Model:
...
@@ -105,14 +105,13 @@ def get_trivial_model(num_classes: int) -> tf.keras.Model:
lr
=
0.01
lr
=
0.01
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
learning_rate
=
lr
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
learning_rate
=
lr
)
loss_obj
=
tf
.
keras
.
losses
.
SparseCategoricalCrossentropy
()
loss_obj
=
tf
.
keras
.
losses
.
SparseCategoricalCrossentropy
()
model
.
compile
(
optimizer
=
optimizer
,
model
.
compile
(
optimizer
=
optimizer
,
loss
=
loss_obj
,
run_eagerly
=
True
)
loss
=
loss_obj
,
run_eagerly
=
True
)
return
model
return
model
def
get_trivial_data
()
->
tf
.
data
.
Dataset
:
def
get_trivial_data
()
->
tf
.
data
.
Dataset
:
"""Gets trivial data in the ImageNet size."""
"""Gets trivial data in the ImageNet size."""
def
generate_data
(
_
)
->
tf
.
data
.
Dataset
:
def
generate_data
(
_
)
->
tf
.
data
.
Dataset
:
image
=
tf
.
zeros
(
shape
=
(
224
,
224
,
3
),
dtype
=
tf
.
float32
)
image
=
tf
.
zeros
(
shape
=
(
224
,
224
,
3
),
dtype
=
tf
.
float32
)
label
=
tf
.
zeros
([
1
],
dtype
=
tf
.
int32
)
label
=
tf
.
zeros
([
1
],
dtype
=
tf
.
int32
)
...
@@ -120,8 +119,8 @@ def get_trivial_data() -> tf.data.Dataset:
...
@@ -120,8 +119,8 @@ def get_trivial_data() -> tf.data.Dataset:
dataset
=
tf
.
data
.
Dataset
.
range
(
1
)
dataset
=
tf
.
data
.
Dataset
.
range
(
1
)
dataset
=
dataset
.
repeat
()
dataset
=
dataset
.
repeat
()
dataset
=
dataset
.
map
(
generate_data
,
dataset
=
dataset
.
map
(
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
generate_data
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
dataset
=
dataset
.
prefetch
(
buffer_size
=
1
).
batch
(
1
)
dataset
=
dataset
.
prefetch
(
buffer_size
=
1
).
batch
(
1
)
return
dataset
return
dataset
...
@@ -165,11 +164,10 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -165,11 +164,10 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
'--mode=train_and_eval'
,
'--mode=train_and_eval'
,
]
]
run
=
functools
.
partial
(
classifier_trainer
.
run
,
run
=
functools
.
partial
(
strategy_override
=
distribution
)
classifier_trainer
.
run
,
strategy_override
=
distribution
)
run_end_to_end
(
main
=
run
,
run_end_to_end
(
extra_flags
=
train_and_eval_flags
,
main
=
run
,
extra_flags
=
train_and_eval_flags
,
model_dir
=
model_dir
)
model_dir
=
model_dir
)
@
combinations
.
generate
(
@
combinations
.
generate
(
combinations
.
combine
(
combinations
.
combine
(
...
@@ -209,29 +207,26 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -209,29 +207,26 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
get_params_override
(
export_params
)
get_params_override
(
export_params
)
]
]
run
=
functools
.
partial
(
classifier_trainer
.
run
,
run
=
functools
.
partial
(
strategy_override
=
distribution
)
classifier_trainer
.
run
,
strategy_override
=
distribution
)
run_end_to_end
(
main
=
run
,
run_end_to_end
(
extra_flags
=
train_and_eval_flags
,
main
=
run
,
extra_flags
=
train_and_eval_flags
,
model_dir
=
model_dir
)
model_dir
=
model_dir
)
run_end_to_end
(
main
=
run
,
extra_flags
=
export_flags
,
model_dir
=
model_dir
)
run_end_to_end
(
main
=
run
,
extra_flags
=
export_flags
,
model_dir
=
model_dir
)
self
.
assertTrue
(
os
.
path
.
exists
(
export_path
))
self
.
assertTrue
(
os
.
path
.
exists
(
export_path
))
@
combinations
.
generate
(
@
combinations
.
generate
(
combinations
.
combine
(
combinations
.
combine
(
distribution
=
[
distribution
=
[
strategy_combinations
.
tpu_strategy
,
strategy_combinations
.
tpu_strategy
,
],
],
model
=
[
model
=
[
'efficientnet'
,
'efficientnet'
,
'resnet'
,
'resnet'
,
],
],
mode
=
'eager'
,
mode
=
'eager'
,
dataset
=
'imagenet'
,
dataset
=
'imagenet'
,
dtype
=
'bfloat16'
,
dtype
=
'bfloat16'
,
))
))
def
test_tpu_train
(
self
,
distribution
,
model
,
dataset
,
dtype
):
def
test_tpu_train
(
self
,
distribution
,
model
,
dataset
,
dtype
):
"""Test train_and_eval and export for Keras classifier models."""
"""Test train_and_eval and export for Keras classifier models."""
# Some parameters are not defined as flags (e.g. cannot run
# Some parameters are not defined as flags (e.g. cannot run
...
@@ -248,11 +243,10 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -248,11 +243,10 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
'--mode=train_and_eval'
,
'--mode=train_and_eval'
,
]
]
run
=
functools
.
partial
(
classifier_trainer
.
run
,
run
=
functools
.
partial
(
strategy_override
=
distribution
)
classifier_trainer
.
run
,
strategy_override
=
distribution
)
run_end_to_end
(
main
=
run
,
run_end_to_end
(
extra_flags
=
train_and_eval_flags
,
main
=
run
,
extra_flags
=
train_and_eval_flags
,
model_dir
=
model_dir
)
model_dir
=
model_dir
)
@
combinations
.
generate
(
distribution_strategy_combinations
())
@
combinations
.
generate
(
distribution_strategy_combinations
())
def
test_end_to_end_invalid_mode
(
self
,
distribution
,
model
,
dataset
):
def
test_end_to_end_invalid_mode
(
self
,
distribution
,
model
,
dataset
):
...
@@ -266,8 +260,8 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -266,8 +260,8 @@ class ClassifierTest(tf.test.TestCase, parameterized.TestCase):
get_params_override
(
basic_params_override
()),
get_params_override
(
basic_params_override
()),
]
]
run
=
functools
.
partial
(
classifier_trainer
.
run
,
run
=
functools
.
partial
(
strategy_override
=
distribution
)
classifier_trainer
.
run
,
strategy_override
=
distribution
)
with
self
.
assertRaises
(
ValueError
):
with
self
.
assertRaises
(
ValueError
):
run_end_to_end
(
main
=
run
,
extra_flags
=
extra_flags
,
model_dir
=
model_dir
)
run_end_to_end
(
main
=
run
,
extra_flags
=
extra_flags
,
model_dir
=
model_dir
)
...
@@ -292,9 +286,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
...
@@ -292,9 +286,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
model
=
base_configs
.
ModelConfig
(
model
=
base_configs
.
ModelConfig
(
model_params
=
{
model_params
=
{
'model_name'
:
model_name
,
'model_name'
:
model_name
,
},
},))
)
)
size
=
classifier_trainer
.
get_image_size_from_model
(
config
)
size
=
classifier_trainer
.
get_image_size_from_model
(
config
)
self
.
assertEqual
(
size
,
expected
)
self
.
assertEqual
(
size
,
expected
)
...
@@ -306,16 +298,13 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
...
@@ -306,16 +298,13 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
)
)
def
test_get_loss_scale
(
self
,
loss_scale
,
dtype
,
expected
):
def
test_get_loss_scale
(
self
,
loss_scale
,
dtype
,
expected
):
config
=
base_configs
.
ExperimentConfig
(
config
=
base_configs
.
ExperimentConfig
(
runtime
=
base_configs
.
RuntimeConfig
(
runtime
=
base_configs
.
RuntimeConfig
(
loss_scale
=
loss_scale
),
loss_scale
=
loss_scale
),
train_dataset
=
dataset_factory
.
DatasetConfig
(
dtype
=
dtype
))
train_dataset
=
dataset_factory
.
DatasetConfig
(
dtype
=
dtype
))
ls
=
classifier_trainer
.
get_loss_scale
(
config
,
fp16_default
=
128
)
ls
=
classifier_trainer
.
get_loss_scale
(
config
,
fp16_default
=
128
)
self
.
assertEqual
(
ls
,
expected
)
self
.
assertEqual
(
ls
,
expected
)
@
parameterized
.
named_parameters
(
@
parameterized
.
named_parameters
((
'float16'
,
'float16'
),
(
'float16'
,
'float16'
),
(
'bfloat16'
,
'bfloat16'
))
(
'bfloat16'
,
'bfloat16'
)
)
def
test_initialize
(
self
,
dtype
):
def
test_initialize
(
self
,
dtype
):
config
=
base_configs
.
ExperimentConfig
(
config
=
base_configs
.
ExperimentConfig
(
runtime
=
base_configs
.
RuntimeConfig
(
runtime
=
base_configs
.
RuntimeConfig
(
...
@@ -332,6 +321,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
...
@@ -332,6 +321,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
class
EmptyClass
:
class
EmptyClass
:
pass
pass
fake_ds_builder
=
EmptyClass
()
fake_ds_builder
=
EmptyClass
()
fake_ds_builder
.
dtype
=
dtype
fake_ds_builder
.
dtype
=
dtype
fake_ds_builder
.
config
=
EmptyClass
()
fake_ds_builder
.
config
=
EmptyClass
()
...
@@ -366,9 +356,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
...
@@ -366,9 +356,7 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
clean_model
=
get_trivial_model
(
10
)
clean_model
=
get_trivial_model
(
10
)
weights_before_load
=
copy
.
deepcopy
(
clean_model
.
get_weights
())
weights_before_load
=
copy
.
deepcopy
(
clean_model
.
get_weights
())
initial_epoch
=
classifier_trainer
.
resume_from_checkpoint
(
initial_epoch
=
classifier_trainer
.
resume_from_checkpoint
(
model
=
clean_model
,
model
=
clean_model
,
model_dir
=
model_dir
,
train_steps
=
train_steps
)
model_dir
=
model_dir
,
train_steps
=
train_steps
)
self
.
assertEqual
(
initial_epoch
,
1
)
self
.
assertEqual
(
initial_epoch
,
1
)
self
.
assertNotAllClose
(
weights_before_load
,
clean_model
.
get_weights
())
self
.
assertNotAllClose
(
weights_before_load
,
clean_model
.
get_weights
())
...
@@ -383,5 +371,6 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
...
@@ -383,5 +371,6 @@ class UtilTests(parameterized.TestCase, tf.test.TestCase):
self
.
assertTrue
(
os
.
path
.
exists
(
saved_params_path
))
self
.
assertTrue
(
os
.
path
.
exists
(
saved_params_path
))
tf
.
io
.
gfile
.
rmtree
(
model_dir
)
tf
.
io
.
gfile
.
rmtree
(
model_dir
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
official/vision/image_classification/configs/base_configs.py
View file @
88253ce5
...
@@ -18,7 +18,6 @@ from __future__ import absolute_import
...
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
from
typing
import
Any
,
List
,
Mapping
,
Optional
from
typing
import
Any
,
List
,
Mapping
,
Optional
import
dataclasses
import
dataclasses
...
...
official/vision/image_classification/configs/configs.py
View file @
88253ce5
...
@@ -37,7 +37,6 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
...
@@ -37,7 +37,6 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
train: A `TrainConfig` instance.
train: A `TrainConfig` instance.
evaluation: An `EvalConfig` instance.
evaluation: An `EvalConfig` instance.
model: A `ModelConfig` instance.
model: A `ModelConfig` instance.
"""
"""
export
:
base_configs
.
ExportConfig
=
base_configs
.
ExportConfig
()
export
:
base_configs
.
ExportConfig
=
base_configs
.
ExportConfig
()
runtime
:
base_configs
.
RuntimeConfig
=
base_configs
.
RuntimeConfig
()
runtime
:
base_configs
.
RuntimeConfig
=
base_configs
.
RuntimeConfig
()
...
@@ -49,16 +48,15 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
...
@@ -49,16 +48,15 @@ class EfficientNetImageNetConfig(base_configs.ExperimentConfig):
resume_checkpoint
=
True
,
resume_checkpoint
=
True
,
epochs
=
500
,
epochs
=
500
,
steps
=
None
,
steps
=
None
,
callbacks
=
base_configs
.
CallbacksConfig
(
enable_checkpoint_and_export
=
True
,
callbacks
=
base_configs
.
CallbacksConfig
(
enable_tensorboard
=
True
),
enable_checkpoint_and_export
=
True
,
enable_tensorboard
=
True
),
metrics
=
[
'accuracy'
,
'top_5'
],
metrics
=
[
'accuracy'
,
'top_5'
],
time_history
=
base_configs
.
TimeHistoryConfig
(
log_steps
=
100
),
time_history
=
base_configs
.
TimeHistoryConfig
(
log_steps
=
100
),
tensorboard
=
base_configs
.
TensorboardConfig
(
track_lr
=
True
,
tensorboard
=
base_configs
.
TensorboardConfig
(
write_model_weights
=
False
),
track_lr
=
True
,
write_model_weights
=
False
),
set_epoch_loop
=
False
)
set_epoch_loop
=
False
)
evaluation
:
base_configs
.
EvalConfig
=
base_configs
.
EvalConfig
(
evaluation
:
base_configs
.
EvalConfig
=
base_configs
.
EvalConfig
(
epochs_between_evals
=
1
,
epochs_between_evals
=
1
,
steps
=
None
)
steps
=
None
)
model
:
base_configs
.
ModelConfig
=
\
model
:
base_configs
.
ModelConfig
=
\
efficientnet_config
.
EfficientNetModelConfig
()
efficientnet_config
.
EfficientNetModelConfig
()
...
@@ -82,16 +80,15 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
...
@@ -82,16 +80,15 @@ class ResNetImagenetConfig(base_configs.ExperimentConfig):
resume_checkpoint
=
True
,
resume_checkpoint
=
True
,
epochs
=
90
,
epochs
=
90
,
steps
=
None
,
steps
=
None
,
callbacks
=
base_configs
.
CallbacksConfig
(
enable_checkpoint_and_export
=
True
,
callbacks
=
base_configs
.
CallbacksConfig
(
enable_tensorboard
=
True
),
enable_checkpoint_and_export
=
True
,
enable_tensorboard
=
True
),
metrics
=
[
'accuracy'
,
'top_5'
],
metrics
=
[
'accuracy'
,
'top_5'
],
time_history
=
base_configs
.
TimeHistoryConfig
(
log_steps
=
100
),
time_history
=
base_configs
.
TimeHistoryConfig
(
log_steps
=
100
),
tensorboard
=
base_configs
.
TensorboardConfig
(
track_lr
=
True
,
tensorboard
=
base_configs
.
TensorboardConfig
(
write_model_weights
=
False
),
track_lr
=
True
,
write_model_weights
=
False
),
set_epoch_loop
=
False
)
set_epoch_loop
=
False
)
evaluation
:
base_configs
.
EvalConfig
=
base_configs
.
EvalConfig
(
evaluation
:
base_configs
.
EvalConfig
=
base_configs
.
EvalConfig
(
epochs_between_evals
=
1
,
epochs_between_evals
=
1
,
steps
=
None
)
steps
=
None
)
model
:
base_configs
.
ModelConfig
=
resnet_config
.
ResNetModelConfig
()
model
:
base_configs
.
ModelConfig
=
resnet_config
.
ResNetModelConfig
()
...
@@ -109,10 +106,8 @@ def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
...
@@ -109,10 +106,8 @@ def get_config(model: str, dataset: str) -> base_configs.ExperimentConfig:
if
dataset
not
in
dataset_model_config_map
:
if
dataset
not
in
dataset_model_config_map
:
raise
KeyError
(
'Invalid dataset received. Received: {}. Supported '
raise
KeyError
(
'Invalid dataset received. Received: {}. Supported '
'datasets include: {}'
.
format
(
'datasets include: {}'
.
format
(
dataset
,
dataset
,
', '
.
join
(
dataset_model_config_map
.
keys
())))
', '
.
join
(
dataset_model_config_map
.
keys
())))
raise
KeyError
(
'Invalid model received. Received: {}. Supported models for'
raise
KeyError
(
'Invalid model received. Received: {}. Supported models for'
'{} include: {}'
.
format
(
'{} include: {}'
.
format
(
model
,
model
,
dataset
,
dataset
,
', '
.
join
(
dataset_model_config_map
[
dataset
].
keys
())))
', '
.
join
(
dataset_model_config_map
[
dataset
].
keys
())))
official/vision/image_classification/dataset_factory.py
View file @
88253ce5
...
@@ -21,6 +21,7 @@ from __future__ import print_function
...
@@ -21,6 +21,7 @@ from __future__ import print_function
import
os
import
os
from
typing
import
Any
,
List
,
Optional
,
Tuple
,
Mapping
,
Union
from
typing
import
Any
,
List
,
Optional
,
Tuple
,
Mapping
,
Union
from
absl
import
logging
from
absl
import
logging
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -30,7 +31,6 @@ from official.modeling.hyperparams import base_config
...
@@ -30,7 +31,6 @@ from official.modeling.hyperparams import base_config
from
official.vision.image_classification
import
augment
from
official.vision.image_classification
import
augment
from
official.vision.image_classification
import
preprocessing
from
official.vision.image_classification
import
preprocessing
AUGMENTERS
=
{
AUGMENTERS
=
{
'autoaugment'
:
augment
.
AutoAugment
,
'autoaugment'
:
augment
.
AutoAugment
,
'randaugment'
:
augment
.
RandAugment
,
'randaugment'
:
augment
.
RandAugment
,
...
@@ -42,8 +42,8 @@ class AugmentConfig(base_config.Config):
...
@@ -42,8 +42,8 @@ class AugmentConfig(base_config.Config):
"""Configuration for image augmenters.
"""Configuration for image augmenters.
Attributes:
Attributes:
name: The name of the image augmentation to use. Possible options are
name: The name of the image augmentation to use. Possible options are
None
None
(default), 'autoaugment', or 'randaugment'.
(default), 'autoaugment', or 'randaugment'.
params: Any paramaters used to initialize the augmenter.
params: Any paramaters used to initialize the augmenter.
"""
"""
name
:
Optional
[
str
]
=
None
name
:
Optional
[
str
]
=
None
...
@@ -68,17 +68,17 @@ class DatasetConfig(base_config.Config):
...
@@ -68,17 +68,17 @@ class DatasetConfig(base_config.Config):
'tfds' (load using TFDS), 'records' (load from TFRecords), or 'synthetic'
'tfds' (load using TFDS), 'records' (load from TFRecords), or 'synthetic'
(generate dummy synthetic data without reading from files).
(generate dummy synthetic data without reading from files).
split: The split of the dataset. Usually 'train', 'validation', or 'test'.
split: The split of the dataset. Usually 'train', 'validation', or 'test'.
image_size: The size of the image in the dataset. This assumes that
image_size: The size of the image in the dataset. This assumes that
`width`
`width`
== `height`. Set to 'infer' to infer the image size from TFDS
== `height`. Set to 'infer' to infer the image size from TFDS
info. This
info. This
requires `name` to be a registered dataset in TFDS.
requires `name` to be a registered dataset in TFDS.
num_classes: The number of classes given by the dataset. Set to 'infer'
num_classes: The number of classes given by the dataset. Set to 'infer'
to
to
infer the image size from TFDS info. This requires `name` to be a
infer the image size from TFDS info. This requires `name` to be a
registered dataset in TFDS.
registered dataset in TFDS.
num_channels: The number of channels given by the dataset. Set to 'infer'
num_channels: The number of channels given by the dataset. Set to 'infer'
to
to
infer the image size from TFDS info. This requires `name` to be a
infer the image size from TFDS info. This requires `name` to be a
registered dataset in TFDS.
registered dataset in TFDS.
num_examples: The number of examples given by the dataset. Set to 'infer'
num_examples: The number of examples given by the dataset. Set to 'infer'
to
to
infer the image size from TFDS info. This requires `name` to be a
infer the image size from TFDS info. This requires `name` to be a
registered dataset in TFDS.
registered dataset in TFDS.
batch_size: The base batch size for the dataset.
batch_size: The base batch size for the dataset.
use_per_replica_batch_size: Whether to scale the batch size based on
use_per_replica_batch_size: Whether to scale the batch size based on
...
@@ -284,10 +284,10 @@ class DatasetBuilder:
...
@@ -284,10 +284,10 @@ class DatasetBuilder:
"""
"""
if
strategy
:
if
strategy
:
if
strategy
.
num_replicas_in_sync
!=
self
.
config
.
num_devices
:
if
strategy
.
num_replicas_in_sync
!=
self
.
config
.
num_devices
:
logging
.
warn
(
'Passed a strategy with %d devices, but expected'
logging
.
warn
(
'%d devices.'
,
'Passed a strategy with %d devices, but expected'
strategy
.
num_replicas_in_sync
,
'%d devices.'
,
strategy
.
num_replicas_in_sync
,
self
.
config
.
num_devices
)
self
.
config
.
num_devices
)
dataset
=
strategy
.
experimental_distribute_datasets_from_function
(
dataset
=
strategy
.
experimental_distribute_datasets_from_function
(
self
.
_build
)
self
.
_build
)
else
:
else
:
...
@@ -295,8 +295,9 @@ class DatasetBuilder:
...
@@ -295,8 +295,9 @@ class DatasetBuilder:
return
dataset
return
dataset
def
_build
(
self
,
input_context
:
tf
.
distribute
.
InputContext
=
None
def
_build
(
)
->
tf
.
data
.
Dataset
:
self
,
input_context
:
tf
.
distribute
.
InputContext
=
None
)
->
tf
.
data
.
Dataset
:
"""Construct a dataset end-to-end and return it.
"""Construct a dataset end-to-end and return it.
Args:
Args:
...
@@ -328,8 +329,7 @@ class DatasetBuilder:
...
@@ -328,8 +329,7 @@ class DatasetBuilder:
logging
.
info
(
'Using TFDS to load data.'
)
logging
.
info
(
'Using TFDS to load data.'
)
builder
=
tfds
.
builder
(
self
.
config
.
name
,
builder
=
tfds
.
builder
(
self
.
config
.
name
,
data_dir
=
self
.
config
.
data_dir
)
data_dir
=
self
.
config
.
data_dir
)
if
self
.
config
.
download
:
if
self
.
config
.
download
:
builder
.
download_and_prepare
()
builder
.
download_and_prepare
()
...
@@ -380,8 +380,8 @@ class DatasetBuilder:
...
@@ -380,8 +380,8 @@ class DatasetBuilder:
dataset
=
tf
.
data
.
Dataset
.
range
(
1
)
dataset
=
tf
.
data
.
Dataset
.
range
(
1
)
dataset
=
dataset
.
repeat
()
dataset
=
dataset
.
repeat
()
dataset
=
dataset
.
map
(
generate_data
,
dataset
=
dataset
.
map
(
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
generate_data
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
return
dataset
return
dataset
def
pipeline
(
self
,
dataset
:
tf
.
data
.
Dataset
)
->
tf
.
data
.
Dataset
:
def
pipeline
(
self
,
dataset
:
tf
.
data
.
Dataset
)
->
tf
.
data
.
Dataset
:
...
@@ -393,14 +393,14 @@ class DatasetBuilder:
...
@@ -393,14 +393,14 @@ class DatasetBuilder:
Returns:
Returns:
A TensorFlow dataset outputting batched images and labels.
A TensorFlow dataset outputting batched images and labels.
"""
"""
if
(
self
.
config
.
builder
!=
'tfds'
and
self
.
input_context
if
(
self
.
config
.
builder
!=
'tfds'
and
self
.
input_context
and
and
self
.
input_context
.
num_input_pipelines
>
1
):
self
.
input_context
.
num_input_pipelines
>
1
):
dataset
=
dataset
.
shard
(
self
.
input_context
.
num_input_pipelines
,
dataset
=
dataset
.
shard
(
self
.
input_context
.
num_input_pipelines
,
self
.
input_context
.
input_pipeline_id
)
self
.
input_context
.
input_pipeline_id
)
logging
.
info
(
'Sharding the dataset: input_pipeline_id=%d '
logging
.
info
(
'num_
input_pipeline
s
=%d'
,
'Sharding the dataset:
input_pipeline
_id
=%d
'
self
.
input_context
.
num_input_pipelines
,
'num_input_pipelines=%d'
,
self
.
input_context
.
num_input_pipelines
,
self
.
input_context
.
input_pipeline_id
)
self
.
input_context
.
input_pipeline_id
)
if
self
.
is_training
and
self
.
config
.
builder
==
'records'
:
if
self
.
is_training
and
self
.
config
.
builder
==
'records'
:
# Shuffle the input files.
# Shuffle the input files.
...
@@ -429,8 +429,8 @@ class DatasetBuilder:
...
@@ -429,8 +429,8 @@ class DatasetBuilder:
preprocess
=
self
.
parse_record
preprocess
=
self
.
parse_record
else
:
else
:
preprocess
=
self
.
preprocess
preprocess
=
self
.
preprocess
dataset
=
dataset
.
map
(
preprocess
,
dataset
=
dataset
.
map
(
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
preprocess
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
if
self
.
input_context
and
self
.
config
.
num_devices
>
1
:
if
self
.
input_context
and
self
.
config
.
num_devices
>
1
:
if
not
self
.
config
.
use_per_replica_batch_size
:
if
not
self
.
config
.
use_per_replica_batch_size
:
...
@@ -444,11 +444,11 @@ class DatasetBuilder:
...
@@ -444,11 +444,11 @@ class DatasetBuilder:
# The batch size of the dataset will be multiplied by the number of
# The batch size of the dataset will be multiplied by the number of
# replicas automatically when strategy.distribute_datasets_from_function
# replicas automatically when strategy.distribute_datasets_from_function
# is called, so we use local batch size here.
# is called, so we use local batch size here.
dataset
=
dataset
.
batch
(
self
.
local_batch_size
,
dataset
=
dataset
.
batch
(
drop_remainder
=
self
.
is_training
)
self
.
local_batch_size
,
drop_remainder
=
self
.
is_training
)
else
:
else
:
dataset
=
dataset
.
batch
(
self
.
global_batch_size
,
dataset
=
dataset
.
batch
(
drop_remainder
=
self
.
is_training
)
self
.
global_batch_size
,
drop_remainder
=
self
.
is_training
)
# Prefetch overlaps in-feed with training
# Prefetch overlaps in-feed with training
dataset
=
dataset
.
prefetch
(
tf
.
data
.
experimental
.
AUTOTUNE
)
dataset
=
dataset
.
prefetch
(
tf
.
data
.
experimental
.
AUTOTUNE
)
...
@@ -470,24 +470,15 @@ class DatasetBuilder:
...
@@ -470,24 +470,15 @@ class DatasetBuilder:
def
parse_record
(
self
,
record
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
def
parse_record
(
self
,
record
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
"""Parse an ImageNet record from a serialized string Tensor."""
"""Parse an ImageNet record from a serialized string Tensor."""
keys_to_features
=
{
keys_to_features
=
{
'image/encoded'
:
'image/encoded'
:
tf
.
io
.
FixedLenFeature
((),
tf
.
string
,
''
),
tf
.
io
.
FixedLenFeature
((),
tf
.
string
,
''
),
'image/format'
:
tf
.
io
.
FixedLenFeature
((),
tf
.
string
,
'jpeg'
),
'image/format'
:
'image/class/label'
:
tf
.
io
.
FixedLenFeature
([],
tf
.
int64
,
-
1
),
tf
.
io
.
FixedLenFeature
((),
tf
.
string
,
'jpeg'
),
'image/class/text'
:
tf
.
io
.
FixedLenFeature
([],
tf
.
string
,
''
),
'image/class/label'
:
'image/object/bbox/xmin'
:
tf
.
io
.
VarLenFeature
(
dtype
=
tf
.
float32
),
tf
.
io
.
FixedLenFeature
([],
tf
.
int64
,
-
1
),
'image/object/bbox/ymin'
:
tf
.
io
.
VarLenFeature
(
dtype
=
tf
.
float32
),
'image/class/text'
:
'image/object/bbox/xmax'
:
tf
.
io
.
VarLenFeature
(
dtype
=
tf
.
float32
),
tf
.
io
.
FixedLenFeature
([],
tf
.
string
,
''
),
'image/object/bbox/ymax'
:
tf
.
io
.
VarLenFeature
(
dtype
=
tf
.
float32
),
'image/object/bbox/xmin'
:
'image/object/class/label'
:
tf
.
io
.
VarLenFeature
(
dtype
=
tf
.
int64
),
tf
.
io
.
VarLenFeature
(
dtype
=
tf
.
float32
),
'image/object/bbox/ymin'
:
tf
.
io
.
VarLenFeature
(
dtype
=
tf
.
float32
),
'image/object/bbox/xmax'
:
tf
.
io
.
VarLenFeature
(
dtype
=
tf
.
float32
),
'image/object/bbox/ymax'
:
tf
.
io
.
VarLenFeature
(
dtype
=
tf
.
float32
),
'image/object/class/label'
:
tf
.
io
.
VarLenFeature
(
dtype
=
tf
.
int64
),
}
}
parsed
=
tf
.
io
.
parse_single_example
(
record
,
keys_to_features
)
parsed
=
tf
.
io
.
parse_single_example
(
record
,
keys_to_features
)
...
@@ -502,8 +493,8 @@ class DatasetBuilder:
...
@@ -502,8 +493,8 @@ class DatasetBuilder:
return
image
,
label
return
image
,
label
def
preprocess
(
self
,
image
:
tf
.
Tensor
,
label
:
tf
.
Tensor
def
preprocess
(
self
,
image
:
tf
.
Tensor
,
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
label
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
"""Apply image preprocessing and augmentation to the image and label."""
"""Apply image preprocessing and augmentation to the image and label."""
if
self
.
is_training
:
if
self
.
is_training
:
image
=
preprocessing
.
preprocess_for_train
(
image
=
preprocessing
.
preprocess_for_train
(
...
...
official/vision/image_classification/efficientnet/common_modules.py
View file @
88253ce5
...
@@ -79,7 +79,7 @@ def get_batch_norm(batch_norm_type: Text) -> tf.keras.layers.BatchNormalization:
...
@@ -79,7 +79,7 @@ def get_batch_norm(batch_norm_type: Text) -> tf.keras.layers.BatchNormalization:
Args:
Args:
batch_norm_type: The type of batch normalization layer implementation. `tpu`
batch_norm_type: The type of batch normalization layer implementation. `tpu`
will use `TpuBatchNormalization`.
will use `TpuBatchNormalization`.
Returns:
Returns:
An instance of `tf.keras.layers.BatchNormalization`.
An instance of `tf.keras.layers.BatchNormalization`.
...
@@ -95,8 +95,10 @@ def count_params(model, trainable_only=True):
...
@@ -95,8 +95,10 @@ def count_params(model, trainable_only=True):
if
not
trainable_only
:
if
not
trainable_only
:
return
model
.
count_params
()
return
model
.
count_params
()
else
:
else
:
return
int
(
np
.
sum
([
tf
.
keras
.
backend
.
count_params
(
p
)
return
int
(
for
p
in
model
.
trainable_weights
]))
np
.
sum
([
tf
.
keras
.
backend
.
count_params
(
p
)
for
p
in
model
.
trainable_weights
]))
def
load_weights
(
model
:
tf
.
keras
.
Model
,
def
load_weights
(
model
:
tf
.
keras
.
Model
,
...
@@ -107,8 +109,8 @@ def load_weights(model: tf.keras.Model,
...
@@ -107,8 +109,8 @@ def load_weights(model: tf.keras.Model,
Args:
Args:
model: the model to load weights into
model: the model to load weights into
model_weights_path: the path of the model weights
model_weights_path: the path of the model weights
weights_format: the model weights format. One of 'saved_model', 'h5',
weights_format: the model weights format. One of 'saved_model', 'h5',
or
or
'checkpoint'.
'checkpoint'.
"""
"""
if
weights_format
==
'saved_model'
:
if
weights_format
==
'saved_model'
:
loaded_model
=
tf
.
keras
.
models
.
load_model
(
model_weights_path
)
loaded_model
=
tf
.
keras
.
models
.
load_model
(
model_weights_path
)
...
...
official/vision/image_classification/efficientnet/efficientnet_model.py
View file @
88253ce5
...
@@ -64,11 +64,11 @@ class ModelConfig(base_config.Config):
...
@@ -64,11 +64,11 @@ class ModelConfig(base_config.Config):
# (input_filters, output_filters, kernel_size, num_repeat,
# (input_filters, output_filters, kernel_size, num_repeat,
# expand_ratio, strides, se_ratio)
# expand_ratio, strides, se_ratio)
# pylint: disable=bad-whitespace
# pylint: disable=bad-whitespace
BlockConfig
.
from_args
(
32
,
16
,
3
,
1
,
1
,
(
1
,
1
),
0.25
),
BlockConfig
.
from_args
(
32
,
16
,
3
,
1
,
1
,
(
1
,
1
),
0.25
),
BlockConfig
.
from_args
(
16
,
24
,
3
,
2
,
6
,
(
2
,
2
),
0.25
),
BlockConfig
.
from_args
(
16
,
24
,
3
,
2
,
6
,
(
2
,
2
),
0.25
),
BlockConfig
.
from_args
(
24
,
40
,
5
,
2
,
6
,
(
2
,
2
),
0.25
),
BlockConfig
.
from_args
(
24
,
40
,
5
,
2
,
6
,
(
2
,
2
),
0.25
),
BlockConfig
.
from_args
(
40
,
80
,
3
,
3
,
6
,
(
2
,
2
),
0.25
),
BlockConfig
.
from_args
(
40
,
80
,
3
,
3
,
6
,
(
2
,
2
),
0.25
),
BlockConfig
.
from_args
(
80
,
112
,
5
,
3
,
6
,
(
1
,
1
),
0.25
),
BlockConfig
.
from_args
(
80
,
112
,
5
,
3
,
6
,
(
1
,
1
),
0.25
),
BlockConfig
.
from_args
(
112
,
192
,
5
,
4
,
6
,
(
2
,
2
),
0.25
),
BlockConfig
.
from_args
(
112
,
192
,
5
,
4
,
6
,
(
2
,
2
),
0.25
),
BlockConfig
.
from_args
(
192
,
320
,
3
,
1
,
6
,
(
1
,
1
),
0.25
),
BlockConfig
.
from_args
(
192
,
320
,
3
,
1
,
6
,
(
1
,
1
),
0.25
),
# pylint: enable=bad-whitespace
# pylint: enable=bad-whitespace
...
@@ -128,8 +128,7 @@ DENSE_KERNEL_INITIALIZER = {
...
@@ -128,8 +128,7 @@ DENSE_KERNEL_INITIALIZER = {
}
}
def
round_filters
(
filters
:
int
,
def
round_filters
(
filters
:
int
,
config
:
ModelConfig
)
->
int
:
config
:
ModelConfig
)
->
int
:
"""Round number of filters based on width coefficient."""
"""Round number of filters based on width coefficient."""
width_coefficient
=
config
.
width_coefficient
width_coefficient
=
config
.
width_coefficient
min_depth
=
config
.
min_depth
min_depth
=
config
.
min_depth
...
@@ -189,21 +188,24 @@ def conv2d_block(inputs: tf.Tensor,
...
@@ -189,21 +188,24 @@ def conv2d_block(inputs: tf.Tensor,
init_kwargs
.
update
({
'depthwise_initializer'
:
CONV_KERNEL_INITIALIZER
})
init_kwargs
.
update
({
'depthwise_initializer'
:
CONV_KERNEL_INITIALIZER
})
else
:
else
:
conv2d
=
tf
.
keras
.
layers
.
Conv2D
conv2d
=
tf
.
keras
.
layers
.
Conv2D
init_kwargs
.
update
({
'filters'
:
conv_filters
,
init_kwargs
.
update
({
'kernel_initializer'
:
CONV_KERNEL_INITIALIZER
})
'filters'
:
conv_filters
,
'kernel_initializer'
:
CONV_KERNEL_INITIALIZER
})
x
=
conv2d
(
**
init_kwargs
)(
inputs
)
x
=
conv2d
(
**
init_kwargs
)(
inputs
)
if
use_batch_norm
:
if
use_batch_norm
:
bn_axis
=
1
if
data_format
==
'channels_first'
else
-
1
bn_axis
=
1
if
data_format
==
'channels_first'
else
-
1
x
=
batch_norm
(
axis
=
bn_axis
,
x
=
batch_norm
(
momentum
=
bn_momentum
,
axis
=
bn_axis
,
epsilon
=
bn_epsilon
,
momentum
=
bn_momentum
,
name
=
name
+
'_bn'
)(
x
)
epsilon
=
bn_epsilon
,
name
=
name
+
'_bn'
)(
x
)
if
activation
is
not
None
:
if
activation
is
not
None
:
x
=
tf
.
keras
.
layers
.
Activation
(
activation
,
x
=
tf
.
keras
.
layers
.
Activation
(
activation
,
name
=
name
+
'_activation'
)(
x
)
name
=
name
+
'_activation'
)(
x
)
return
x
return
x
...
@@ -235,42 +237,43 @@ def mb_conv_block(inputs: tf.Tensor,
...
@@ -235,42 +237,43 @@ def mb_conv_block(inputs: tf.Tensor,
if
block
.
fused_conv
:
if
block
.
fused_conv
:
# If we use fused mbconv, skip expansion and use regular conv.
# If we use fused mbconv, skip expansion and use regular conv.
x
=
conv2d_block
(
x
,
x
=
conv2d_block
(
filters
,
x
,
config
,
filters
,
kernel_size
=
block
.
kernel_size
,
config
,
strides
=
block
.
strides
,
kernel_size
=
block
.
kernel_size
,
activation
=
activation
,
strides
=
block
.
strides
,
name
=
prefix
+
'fused'
)
activation
=
activation
,
name
=
prefix
+
'fused'
)
else
:
else
:
if
block
.
expand_ratio
!=
1
:
if
block
.
expand_ratio
!=
1
:
# Expansion phase
# Expansion phase
kernel_size
=
(
1
,
1
)
if
use_depthwise
else
(
3
,
3
)
kernel_size
=
(
1
,
1
)
if
use_depthwise
else
(
3
,
3
)
x
=
conv2d_block
(
x
,
x
=
conv2d_block
(
filters
,
x
,
config
,
filters
,
kernel_size
=
kernel_size
,
config
,
activation
=
activation
,
kernel_size
=
kernel_size
,
name
=
prefix
+
'expand'
)
activation
=
activation
,
name
=
prefix
+
'expand'
)
# Depthwise Convolution
# Depthwise Convolution
if
use_depthwise
:
if
use_depthwise
:
x
=
conv2d_block
(
x
,
x
=
conv2d_block
(
conv_filters
=
None
,
x
,
config
=
config
,
conv_filters
=
None
,
kernel_size
=
block
.
kernel_size
,
config
=
config
,
strides
=
block
.
strides
,
kernel_size
=
block
.
kernel_size
,
activation
=
activation
,
strides
=
block
.
strides
,
depthwise
=
True
,
activation
=
activation
,
name
=
prefix
+
'depthwise'
)
depthwise
=
True
,
name
=
prefix
+
'depthwise'
)
# Squeeze and Excitation phase
# Squeeze and Excitation phase
if
use_se
:
if
use_se
:
assert
block
.
se_ratio
is
not
None
assert
block
.
se_ratio
is
not
None
assert
0
<
block
.
se_ratio
<=
1
assert
0
<
block
.
se_ratio
<=
1
num_reduced_filters
=
max
(
1
,
int
(
num_reduced_filters
=
max
(
1
,
int
(
block
.
input_filters
*
block
.
se_ratio
))
block
.
input_filters
*
block
.
se_ratio
))
if
data_format
==
'channels_first'
:
if
data_format
==
'channels_first'
:
se_shape
=
(
filters
,
1
,
1
)
se_shape
=
(
filters
,
1
,
1
)
...
@@ -280,53 +283,51 @@ def mb_conv_block(inputs: tf.Tensor,
...
@@ -280,53 +283,51 @@ def mb_conv_block(inputs: tf.Tensor,
se
=
tf
.
keras
.
layers
.
GlobalAveragePooling2D
(
name
=
prefix
+
'se_squeeze'
)(
x
)
se
=
tf
.
keras
.
layers
.
GlobalAveragePooling2D
(
name
=
prefix
+
'se_squeeze'
)(
x
)
se
=
tf
.
keras
.
layers
.
Reshape
(
se_shape
,
name
=
prefix
+
'se_reshape'
)(
se
)
se
=
tf
.
keras
.
layers
.
Reshape
(
se_shape
,
name
=
prefix
+
'se_reshape'
)(
se
)
se
=
conv2d_block
(
se
,
se
=
conv2d_block
(
num_reduced_filters
,
se
,
config
,
num_reduced_filters
,
use_bias
=
True
,
config
,
use_batch_norm
=
False
,
use_bias
=
True
,
activation
=
activation
,
use_batch_norm
=
False
,
name
=
prefix
+
'se_reduce'
)
activation
=
activation
,
se
=
conv2d_block
(
se
,
name
=
prefix
+
'se_reduce'
)
filters
,
se
=
conv2d_block
(
config
,
se
,
use_bias
=
True
,
filters
,
use_batch_norm
=
False
,
config
,
activation
=
'sigmoid'
,
use_bias
=
True
,
name
=
prefix
+
'se_expand'
)
use_batch_norm
=
False
,
activation
=
'sigmoid'
,
name
=
prefix
+
'se_expand'
)
x
=
tf
.
keras
.
layers
.
multiply
([
x
,
se
],
name
=
prefix
+
'se_excite'
)
x
=
tf
.
keras
.
layers
.
multiply
([
x
,
se
],
name
=
prefix
+
'se_excite'
)
# Output phase
# Output phase
x
=
conv2d_block
(
x
,
x
=
conv2d_block
(
block
.
output_filters
,
x
,
block
.
output_filters
,
config
,
activation
=
None
,
name
=
prefix
+
'project'
)
config
,
activation
=
None
,
name
=
prefix
+
'project'
)
# Add identity so that quantization-aware training can insert quantization
# Add identity so that quantization-aware training can insert quantization
# ops correctly.
# ops correctly.
x
=
tf
.
keras
.
layers
.
Activation
(
tf_utils
.
get_activation
(
'identity'
),
x
=
tf
.
keras
.
layers
.
Activation
(
name
=
prefix
+
'id'
)(
x
)
tf_utils
.
get_activation
(
'identity'
),
name
=
prefix
+
'id'
)(
x
)
if
(
block
.
id_skip
if
(
block
.
id_skip
and
all
(
s
==
1
for
s
in
block
.
strides
)
and
and
all
(
s
==
1
for
s
in
block
.
strides
)
block
.
input_filters
==
block
.
output_filters
):
and
block
.
input_filters
==
block
.
output_filters
):
if
drop_connect_rate
and
drop_connect_rate
>
0
:
if
drop_connect_rate
and
drop_connect_rate
>
0
:
# Apply dropconnect
# Apply dropconnect
# The only difference between dropout and dropconnect in TF is scaling by
# The only difference between dropout and dropconnect in TF is scaling by
# drop_connect_rate during training. See:
# drop_connect_rate during training. See:
# https://github.com/keras-team/keras/pull/9898#issuecomment-380577612
# https://github.com/keras-team/keras/pull/9898#issuecomment-380577612
x
=
tf
.
keras
.
layers
.
Dropout
(
drop_connect_rate
,
x
=
tf
.
keras
.
layers
.
Dropout
(
noise_shape
=
(
None
,
1
,
1
,
1
),
drop_connect_rate
,
noise_shape
=
(
None
,
1
,
1
,
1
),
name
=
prefix
+
'drop'
)(
name
=
prefix
+
'drop'
)(
x
)
x
)
x
=
tf
.
keras
.
layers
.
add
([
x
,
inputs
],
name
=
prefix
+
'add'
)
x
=
tf
.
keras
.
layers
.
add
([
x
,
inputs
],
name
=
prefix
+
'add'
)
return
x
return
x
def
efficientnet
(
image_input
:
tf
.
keras
.
layers
.
Input
,
def
efficientnet
(
image_input
:
tf
.
keras
.
layers
.
Input
,
config
:
ModelConfig
):
config
:
ModelConfig
):
"""Creates an EfficientNet graph given the model parameters.
"""Creates an EfficientNet graph given the model parameters.
This function is wrapped by the `EfficientNet` class to make a tf.keras.Model.
This function is wrapped by the `EfficientNet` class to make a tf.keras.Model.
...
@@ -357,19 +358,18 @@ def efficientnet(image_input: tf.keras.layers.Input,
...
@@ -357,19 +358,18 @@ def efficientnet(image_input: tf.keras.layers.Input,
# Happens on GPU/TPU if available.
# Happens on GPU/TPU if available.
x
=
tf
.
keras
.
layers
.
Permute
((
3
,
1
,
2
))(
x
)
x
=
tf
.
keras
.
layers
.
Permute
((
3
,
1
,
2
))(
x
)
if
rescale_input
:
if
rescale_input
:
x
=
preprocessing
.
normalize_images
(
x
,
x
=
preprocessing
.
normalize_images
(
num_channels
=
input_channels
,
x
,
num_channels
=
input_channels
,
dtype
=
dtype
,
data_format
=
data_format
)
dtype
=
dtype
,
data_format
=
data_format
)
# Build stem
# Build stem
x
=
conv2d_block
(
x
,
x
=
conv2d_block
(
round_filters
(
stem_base_filters
,
config
),
x
,
config
,
round_filters
(
stem_base_filters
,
config
),
kernel_size
=
[
3
,
3
],
config
,
strides
=
[
2
,
2
],
kernel_size
=
[
3
,
3
],
activation
=
activation
,
strides
=
[
2
,
2
],
name
=
'stem'
)
activation
=
activation
,
name
=
'stem'
)
# Build blocks
# Build blocks
num_blocks_total
=
sum
(
num_blocks_total
=
sum
(
...
@@ -391,10 +391,7 @@ def efficientnet(image_input: tf.keras.layers.Input,
...
@@ -391,10 +391,7 @@ def efficientnet(image_input: tf.keras.layers.Input,
x
=
mb_conv_block
(
x
,
block
,
config
,
block_prefix
)
x
=
mb_conv_block
(
x
,
block
,
config
,
block_prefix
)
block_num
+=
1
block_num
+=
1
if
block
.
num_repeat
>
1
:
if
block
.
num_repeat
>
1
:
block
=
block
.
replace
(
block
=
block
.
replace
(
input_filters
=
block
.
output_filters
,
strides
=
[
1
,
1
])
input_filters
=
block
.
output_filters
,
strides
=
[
1
,
1
]
)
for
block_idx
in
range
(
block
.
num_repeat
-
1
):
for
block_idx
in
range
(
block
.
num_repeat
-
1
):
drop_rate
=
drop_connect_rate
*
float
(
block_num
)
/
num_blocks_total
drop_rate
=
drop_connect_rate
*
float
(
block_num
)
/
num_blocks_total
...
@@ -404,11 +401,12 @@ def efficientnet(image_input: tf.keras.layers.Input,
...
@@ -404,11 +401,12 @@ def efficientnet(image_input: tf.keras.layers.Input,
block_num
+=
1
block_num
+=
1
# Build top
# Build top
x
=
conv2d_block
(
x
,
x
=
conv2d_block
(
round_filters
(
top_base_filters
,
config
),
x
,
config
,
round_filters
(
top_base_filters
,
config
),
activation
=
activation
,
config
,
name
=
'top'
)
activation
=
activation
,
name
=
'top'
)
# Build classifier
# Build classifier
x
=
tf
.
keras
.
layers
.
GlobalAveragePooling2D
(
name
=
'top_pool'
)(
x
)
x
=
tf
.
keras
.
layers
.
GlobalAveragePooling2D
(
name
=
'top_pool'
)(
x
)
...
@@ -419,7 +417,8 @@ def efficientnet(image_input: tf.keras.layers.Input,
...
@@ -419,7 +417,8 @@ def efficientnet(image_input: tf.keras.layers.Input,
kernel_initializer
=
DENSE_KERNEL_INITIALIZER
,
kernel_initializer
=
DENSE_KERNEL_INITIALIZER
,
kernel_regularizer
=
tf
.
keras
.
regularizers
.
l2
(
weight_decay
),
kernel_regularizer
=
tf
.
keras
.
regularizers
.
l2
(
weight_decay
),
bias_regularizer
=
tf
.
keras
.
regularizers
.
l2
(
weight_decay
),
bias_regularizer
=
tf
.
keras
.
regularizers
.
l2
(
weight_decay
),
name
=
'logits'
)(
x
)
name
=
'logits'
)(
x
)
x
=
tf
.
keras
.
layers
.
Activation
(
'softmax'
,
name
=
'probs'
)(
x
)
x
=
tf
.
keras
.
layers
.
Activation
(
'softmax'
,
name
=
'probs'
)(
x
)
return
x
return
x
...
@@ -439,8 +438,7 @@ class EfficientNet(tf.keras.Model):
...
@@ -439,8 +438,7 @@ class EfficientNet(tf.keras.Model):
Args:
Args:
config: (optional) the main model parameters to create the model
config: (optional) the main model parameters to create the model
overrides: (optional) a dict containing keys that can override
overrides: (optional) a dict containing keys that can override config
config
"""
"""
overrides
=
overrides
or
{}
overrides
=
overrides
or
{}
config
=
config
or
ModelConfig
()
config
=
config
or
ModelConfig
()
...
@@ -457,9 +455,7 @@ class EfficientNet(tf.keras.Model):
...
@@ -457,9 +455,7 @@ class EfficientNet(tf.keras.Model):
# Cast to float32 in case we have a different model dtype
# Cast to float32 in case we have a different model dtype
output
=
tf
.
cast
(
output
,
tf
.
float32
)
output
=
tf
.
cast
(
output
,
tf
.
float32
)
logging
.
info
(
'Building model %s with params %s'
,
logging
.
info
(
'Building model %s with params %s'
,
model_name
,
self
.
config
)
model_name
,
self
.
config
)
super
(
EfficientNet
,
self
).
__init__
(
super
(
EfficientNet
,
self
).
__init__
(
inputs
=
image_input
,
outputs
=
output
,
name
=
model_name
)
inputs
=
image_input
,
outputs
=
output
,
name
=
model_name
)
...
@@ -477,8 +473,8 @@ class EfficientNet(tf.keras.Model):
...
@@ -477,8 +473,8 @@ class EfficientNet(tf.keras.Model):
Args:
Args:
model_name: the predefined model name
model_name: the predefined model name
model_weights_path: the path to the weights (h5 file or saved model dir)
model_weights_path: the path to the weights (h5 file or saved model dir)
weights_format: the model weights format. One of 'saved_model', 'h5',
weights_format: the model weights format. One of 'saved_model', 'h5',
or
or
'checkpoint'.
'checkpoint'.
overrides: (optional) a dict containing keys that can override config
overrides: (optional) a dict containing keys that can override config
Returns:
Returns:
...
@@ -498,8 +494,7 @@ class EfficientNet(tf.keras.Model):
...
@@ -498,8 +494,7 @@ class EfficientNet(tf.keras.Model):
model
=
cls
(
config
=
config
,
overrides
=
overrides
)
model
=
cls
(
config
=
config
,
overrides
=
overrides
)
if
model_weights_path
:
if
model_weights_path
:
common_modules
.
load_weights
(
model
,
common_modules
.
load_weights
(
model_weights_path
,
model
,
model_weights_path
,
weights_format
=
weights_format
)
weights_format
=
weights_format
)
return
model
return
model
official/vision/image_classification/efficientnet/tfhub_export.py
View file @
88253ce5
...
@@ -30,10 +30,8 @@ from official.vision.image_classification.efficientnet import efficientnet_model
...
@@ -30,10 +30,8 @@ from official.vision.image_classification.efficientnet import efficientnet_model
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_string
(
"model_name"
,
None
,
flags
.
DEFINE_string
(
"model_name"
,
None
,
"EfficientNet model name."
)
"EfficientNet model name."
)
flags
.
DEFINE_string
(
"model_path"
,
None
,
"File path to TF model checkpoint."
)
flags
.
DEFINE_string
(
"model_path"
,
None
,
"File path to TF model checkpoint."
)
flags
.
DEFINE_string
(
"export_path"
,
None
,
flags
.
DEFINE_string
(
"export_path"
,
None
,
"TF-Hub SavedModel destination path to export."
)
"TF-Hub SavedModel destination path to export."
)
...
@@ -65,5 +63,6 @@ def main(argv):
...
@@ -65,5 +63,6 @@ def main(argv):
export_tfhub
(
FLAGS
.
model_path
,
FLAGS
.
export_path
,
FLAGS
.
model_name
)
export_tfhub
(
FLAGS
.
model_path
,
FLAGS
.
export_path
,
FLAGS
.
model_name
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
app
.
run
(
main
)
app
.
run
(
main
)
official/vision/image_classification/learning_rate.py
View file @
88253ce5
...
@@ -29,11 +29,10 @@ BASE_LEARNING_RATE = 0.1
...
@@ -29,11 +29,10 @@ BASE_LEARNING_RATE = 0.1
class
WarmupDecaySchedule
(
tf
.
keras
.
optimizers
.
schedules
.
LearningRateSchedule
):
class
WarmupDecaySchedule
(
tf
.
keras
.
optimizers
.
schedules
.
LearningRateSchedule
):
"""A wrapper for LearningRateSchedule that includes warmup steps."""
"""A wrapper for LearningRateSchedule that includes warmup steps."""
def
__init__
(
def
__init__
(
self
,
self
,
lr_schedule
:
tf
.
keras
.
optimizers
.
schedules
.
LearningRateSchedule
,
lr_schedule
:
tf
.
keras
.
optimizers
.
schedules
.
LearningRateSchedule
,
warmup_steps
:
int
,
warmup_steps
:
int
,
warmup_lr
:
Optional
[
float
]
=
None
):
warmup_lr
:
Optional
[
float
]
=
None
):
"""Add warmup decay to a learning rate schedule.
"""Add warmup decay to a learning rate schedule.
Args:
Args:
...
@@ -42,7 +41,6 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
...
@@ -42,7 +41,6 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
warmup_lr: an optional field for the final warmup learning rate. This
warmup_lr: an optional field for the final warmup learning rate. This
should be provided if the base `lr_schedule` does not contain this
should be provided if the base `lr_schedule` does not contain this
field.
field.
"""
"""
super
(
WarmupDecaySchedule
,
self
).
__init__
()
super
(
WarmupDecaySchedule
,
self
).
__init__
()
self
.
_lr_schedule
=
lr_schedule
self
.
_lr_schedule
=
lr_schedule
...
@@ -63,8 +61,7 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
...
@@ -63,8 +61,7 @@ class WarmupDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
global_step_recomp
=
tf
.
cast
(
step
,
dtype
)
global_step_recomp
=
tf
.
cast
(
step
,
dtype
)
warmup_steps
=
tf
.
cast
(
self
.
_warmup_steps
,
dtype
)
warmup_steps
=
tf
.
cast
(
self
.
_warmup_steps
,
dtype
)
warmup_lr
=
initial_learning_rate
*
global_step_recomp
/
warmup_steps
warmup_lr
=
initial_learning_rate
*
global_step_recomp
/
warmup_steps
lr
=
tf
.
cond
(
global_step_recomp
<
warmup_steps
,
lr
=
tf
.
cond
(
global_step_recomp
<
warmup_steps
,
lambda
:
warmup_lr
,
lambda
:
warmup_lr
,
lambda
:
lr
)
lambda
:
lr
)
return
lr
return
lr
...
...
official/vision/image_classification/learning_rate_test.py
View file @
88253ce5
...
@@ -37,14 +37,13 @@ class LearningRateTests(tf.test.TestCase):
...
@@ -37,14 +37,13 @@ class LearningRateTests(tf.test.TestCase):
decay_steps
=
decay_steps
,
decay_steps
=
decay_steps
,
decay_rate
=
decay_rate
)
decay_rate
=
decay_rate
)
lr
=
learning_rate
.
WarmupDecaySchedule
(
lr
=
learning_rate
.
WarmupDecaySchedule
(
lr_schedule
=
base_lr
,
lr_schedule
=
base_lr
,
warmup_steps
=
warmup_steps
)
warmup_steps
=
warmup_steps
)
for
step
in
range
(
warmup_steps
-
1
):
for
step
in
range
(
warmup_steps
-
1
):
config
=
lr
.
get_config
()
config
=
lr
.
get_config
()
self
.
assertEqual
(
config
[
'warmup_steps'
],
warmup_steps
)
self
.
assertEqual
(
config
[
'warmup_steps'
],
warmup_steps
)
self
.
assertAllClose
(
self
.
evaluate
(
lr
(
step
)),
self
.
assertAllClose
(
step
/
warmup_steps
*
initial_lr
)
self
.
evaluate
(
lr
(
step
)),
step
/
warmup_steps
*
initial_lr
)
def
test_cosine_decay_with_warmup
(
self
):
def
test_cosine_decay_with_warmup
(
self
):
"""Basic computational test for cosine decay with warmup."""
"""Basic computational test for cosine decay with warmup."""
...
...
official/vision/image_classification/mnist_main.py
View file @
88253ce5
...
@@ -19,6 +19,7 @@ from __future__ import print_function
...
@@ -19,6 +19,7 @@ from __future__ import print_function
import
os
import
os
# Import libraries
from
absl
import
app
from
absl
import
app
from
absl
import
flags
from
absl
import
flags
from
absl
import
logging
from
absl
import
logging
...
...
official/vision/image_classification/mnist_test.py
View file @
88253ce5
...
@@ -58,7 +58,8 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -58,7 +58,8 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
"""Test Keras MNIST model with `strategy`."""
"""Test Keras MNIST model with `strategy`."""
extra_flags
=
[
extra_flags
=
[
"-train_epochs"
,
"1"
,
"-train_epochs"
,
"1"
,
# Let TFDS find the metadata folder automatically
# Let TFDS find the metadata folder automatically
"--data_dir="
"--data_dir="
]
]
...
@@ -72,9 +73,10 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -72,9 +73,10 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
tf
.
data
.
Dataset
.
from_tensor_slices
(
dummy_data
),
tf
.
data
.
Dataset
.
from_tensor_slices
(
dummy_data
),
)
)
run
=
functools
.
partial
(
mnist_main
.
run
,
run
=
functools
.
partial
(
datasets_override
=
datasets
,
mnist_main
.
run
,
strategy_override
=
distribution
)
datasets_override
=
datasets
,
strategy_override
=
distribution
)
integration
.
run_synthetic
(
integration
.
run_synthetic
(
main
=
run
,
main
=
run
,
...
...
official/vision/image_classification/optimizer_factory.py
View file @
88253ce5
...
@@ -65,19 +65,19 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
...
@@ -65,19 +65,19 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
"""Construct a new MovingAverage optimizer.
"""Construct a new MovingAverage optimizer.
Args:
Args:
optimizer: `tf.keras.optimizers.Optimizer` that will be
optimizer: `tf.keras.optimizers.Optimizer` that will be
used to compute
used to compute
and apply gradients.
and apply gradients.
average_decay: float. Decay to use to maintain the moving averages
average_decay: float. Decay to use to maintain the moving averages
of
of
trained variables.
trained variables.
start_step: int. What step to start the moving average.
start_step: int. What step to start the moving average.
dynamic_decay: bool. Whether to change the decay based on the number
dynamic_decay: bool. Whether to change the decay based on the number
of
of
optimizer updates. Decay will start at 0.1 and gradually increase
optimizer updates. Decay will start at 0.1 and gradually increase
up to
up to
`average_decay` after each optimizer update. This behavior is
`average_decay` after each optimizer update. This behavior is
similar to
similar to
`tf.train.ExponentialMovingAverage` in TF 1.x.
`tf.train.ExponentialMovingAverage` in TF 1.x.
name: Optional name for the operations created when applying
name: Optional name for the operations created when applying
gradients.
gradients.
Defaults to "moving_average".
Defaults to "moving_average".
**kwargs: keyword arguments. Allowed to be {`clipnorm`,
**kwargs: keyword arguments. Allowed to be {`clipnorm`,
`clipvalue`, `lr`,
`clipvalue`, `lr`,
`decay`}.
`decay`}.
"""
"""
super
(
MovingAverage
,
self
).
__init__
(
name
,
**
kwargs
)
super
(
MovingAverage
,
self
).
__init__
(
name
,
**
kwargs
)
self
.
_optimizer
=
optimizer
self
.
_optimizer
=
optimizer
...
@@ -128,8 +128,8 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
...
@@ -128,8 +128,8 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
strategy
.
extended
.
update
(
v_moving
,
_apply_moving
,
args
=
(
v_normal
,))
strategy
.
extended
.
update
(
v_moving
,
_apply_moving
,
args
=
(
v_normal
,))
ctx
=
tf
.
distribute
.
get_replica_context
()
ctx
=
tf
.
distribute
.
get_replica_context
()
return
ctx
.
merge_call
(
_update
,
args
=
(
zip
(
self
.
_average_weights
,
return
ctx
.
merge_call
(
self
.
_model_weights
),))
_update
,
args
=
(
zip
(
self
.
_average_weights
,
self
.
_model_weights
),))
def
swap_weights
(
self
):
def
swap_weights
(
self
):
"""Swap the average and moving weights.
"""Swap the average and moving weights.
...
@@ -148,12 +148,15 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
...
@@ -148,12 +148,15 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
@
tf
.
function
@
tf
.
function
def
_swap_weights
(
self
):
def
_swap_weights
(
self
):
def
fn_0
(
a
,
b
):
def
fn_0
(
a
,
b
):
a
.
assign_add
(
b
)
a
.
assign_add
(
b
)
return
a
return
a
def
fn_1
(
b
,
a
):
def
fn_1
(
b
,
a
):
b
.
assign
(
a
-
b
)
b
.
assign
(
a
-
b
)
return
b
return
b
def
fn_2
(
a
,
b
):
def
fn_2
(
a
,
b
):
a
.
assign_sub
(
b
)
a
.
assign_sub
(
b
)
return
a
return
a
...
@@ -174,12 +177,14 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
...
@@ -174,12 +177,14 @@ class MovingAverage(tf.keras.optimizers.Optimizer):
Args:
Args:
var_list: List of model variables to be assigned to their average.
var_list: List of model variables to be assigned to their average.
Returns:
Returns:
assign_op: The op corresponding to the assignment operation of
assign_op: The op corresponding to the assignment operation of
variables to their average.
variables to their average.
"""
"""
assign_op
=
tf
.
group
([
assign_op
=
tf
.
group
([
var
.
assign
(
self
.
get_slot
(
var
,
'average'
))
for
var
in
var_list
var
.
assign
(
self
.
get_slot
(
var
,
'average'
))
for
var
in
var_list
if
var
.
trainable
if
var
.
trainable
])
])
return
assign_op
return
assign_op
...
@@ -256,13 +261,13 @@ def build_optimizer(
...
@@ -256,13 +261,13 @@ def build_optimizer(
"""Build the optimizer based on name.
"""Build the optimizer based on name.
Args:
Args:
optimizer_name: String representation of the optimizer name. Examples:
optimizer_name: String representation of the optimizer name. Examples:
sgd,
sgd,
momentum, rmsprop.
momentum, rmsprop.
base_learning_rate: `tf.keras.optimizers.schedules.LearningRateSchedule`
base_learning_rate: `tf.keras.optimizers.schedules.LearningRateSchedule`
base learning rate.
base learning rate.
params: String -> Any dictionary representing the optimizer params.
params: String -> Any dictionary representing the optimizer params.
This
This
should contain optimizer specific parameters such as
should contain optimizer specific parameters such as
`base_learning_rate`,
`base_learning_rate`,
`decay`, etc.
`decay`, etc.
model: The `tf.keras.Model`. This is used for the shadow copy if using
model: The `tf.keras.Model`. This is used for the shadow copy if using
`MovingAverage`.
`MovingAverage`.
...
@@ -279,43 +284,47 @@ def build_optimizer(
...
@@ -279,43 +284,47 @@ def build_optimizer(
# Dispatch on the optimizer name. Every branch reads its hyperparameters
# from `params` with the Keras defaults as fallbacks, so an empty params
# dict yields the framework's stock optimizer configuration.
if optimizer_name == 'sgd':
  logging.info('Using SGD optimizer')
  nesterov = params.get('nesterov', False)
  optimizer = tf.keras.optimizers.SGD(
      learning_rate=base_learning_rate, nesterov=nesterov)
elif optimizer_name == 'momentum':
  logging.info('Using momentum optimizer')
  nesterov = params.get('nesterov', False)
  # `momentum` is mandatory for this optimizer name: a missing key is a
  # configuration error, so a KeyError here is intentional.
  optimizer = tf.keras.optimizers.SGD(
      learning_rate=base_learning_rate,
      momentum=params['momentum'],
      nesterov=nesterov)
elif optimizer_name == 'rmsprop':
  logging.info('Using RMSProp')
  # Accept the legacy `decay` key as an alias for `rho`; an explicit
  # falsy `decay` (e.g. 0) also falls through to `rho`/0.9.
  rho = params.get('decay', None) or params.get('rho', 0.9)
  momentum = params.get('momentum', 0.9)
  epsilon = params.get('epsilon', 1e-07)
  optimizer = tf.keras.optimizers.RMSprop(
      learning_rate=base_learning_rate,
      rho=rho,
      momentum=momentum,
      epsilon=epsilon)
elif optimizer_name == 'adam':
  logging.info('Using Adam')
  beta_1 = params.get('beta_1', 0.9)
  beta_2 = params.get('beta_2', 0.999)
  epsilon = params.get('epsilon', 1e-07)
  optimizer = tf.keras.optimizers.Adam(
      learning_rate=base_learning_rate,
      beta_1=beta_1,
      beta_2=beta_2,
      epsilon=epsilon)
elif optimizer_name == 'adamw':
  logging.info('Using AdamW')
  # AdamW lives in TensorFlow Addons (`tfa`), not core Keras.
  weight_decay = params.get('weight_decay', 0.01)
  beta_1 = params.get('beta_1', 0.9)
  beta_2 = params.get('beta_2', 0.999)
  epsilon = params.get('epsilon', 1e-07)
  optimizer = tfa.optimizers.AdamW(
      weight_decay=weight_decay,
      learning_rate=base_learning_rate,
      beta_1=beta_1,
      beta_2=beta_2,
      epsilon=epsilon)
else:
  raise ValueError('Unknown optimizer %s' % optimizer_name)
...
@@ -330,8 +339,7 @@ def build_optimizer(
...
@@ -330,8 +339,7 @@ def build_optimizer(
raise
ValueError
(
'`model` must be provided if using `MovingAverage`.'
)
raise
ValueError
(
'`model` must be provided if using `MovingAverage`.'
)
logging
.
info
(
'Including moving average decay.'
)
logging
.
info
(
'Including moving average decay.'
)
optimizer
=
MovingAverage
(
optimizer
=
MovingAverage
(
optimizer
=
optimizer
,
optimizer
=
optimizer
,
average_decay
=
moving_average_decay
)
average_decay
=
moving_average_decay
)
optimizer
.
shadow_copy
(
model
)
optimizer
.
shadow_copy
(
model
)
return
optimizer
return
optimizer
...
@@ -358,13 +366,15 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
...
@@ -358,13 +366,15 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
if
lr_multiplier
and
lr_multiplier
>
0
:
if
lr_multiplier
and
lr_multiplier
>
0
:
# Scale the learning rate based on the batch size and a multiplier
# Scale the learning rate based on the batch size and a multiplier
base_lr
*=
lr_multiplier
*
batch_size
base_lr
*=
lr_multiplier
*
batch_size
logging
.
info
(
'Scaling the learning rate based on the batch size '
logging
.
info
(
'multiplier. New base_lr: %f'
,
base_lr
)
'Scaling the learning rate based on the batch size '
'multiplier. New base_lr: %f'
,
base_lr
)
if
decay_type
==
'exponential'
:
if
decay_type
==
'exponential'
:
logging
.
info
(
'Using exponential learning rate with: '
logging
.
info
(
'initial_learning_rate: %f, decay_steps: %d, '
'Using exponential learning rate with: '
'decay_rate: %f'
,
base_lr
,
decay_steps
,
decay_rate
)
'initial_learning_rate: %f, decay_steps: %d, '
'decay_rate: %f'
,
base_lr
,
decay_steps
,
decay_rate
)
lr
=
tf
.
keras
.
optimizers
.
schedules
.
ExponentialDecay
(
lr
=
tf
.
keras
.
optimizers
.
schedules
.
ExponentialDecay
(
initial_learning_rate
=
base_lr
,
initial_learning_rate
=
base_lr
,
decay_steps
=
decay_steps
,
decay_steps
=
decay_steps
,
...
@@ -374,12 +384,11 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
...
@@ -374,12 +384,11 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
steps_per_epoch
=
params
.
examples_per_epoch
//
batch_size
steps_per_epoch
=
params
.
examples_per_epoch
//
batch_size
boundaries
=
[
boundary
*
steps_per_epoch
for
boundary
in
params
.
boundaries
]
boundaries
=
[
boundary
*
steps_per_epoch
for
boundary
in
params
.
boundaries
]
multipliers
=
[
batch_size
*
multiplier
for
multiplier
in
params
.
multipliers
]
multipliers
=
[
batch_size
*
multiplier
for
multiplier
in
params
.
multipliers
]
logging
.
info
(
'Using stepwise learning rate. Parameters: '
logging
.
info
(
'boundaries: %s, values: %s'
,
'Using stepwise learning rate. Parameters: '
boundaries
,
multipliers
)
'boundaries: %s, values: %s'
,
boundaries
,
multipliers
)
lr
=
tf
.
keras
.
optimizers
.
schedules
.
PiecewiseConstantDecay
(
lr
=
tf
.
keras
.
optimizers
.
schedules
.
PiecewiseConstantDecay
(
boundaries
=
boundaries
,
boundaries
=
boundaries
,
values
=
multipliers
)
values
=
multipliers
)
elif
decay_type
==
'cosine_with_warmup'
:
elif
decay_type
==
'cosine_with_warmup'
:
lr
=
learning_rate
.
CosineDecayWithWarmup
(
lr
=
learning_rate
.
CosineDecayWithWarmup
(
batch_size
=
batch_size
,
batch_size
=
batch_size
,
...
@@ -389,7 +398,6 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
...
@@ -389,7 +398,6 @@ def build_learning_rate(params: base_configs.LearningRateConfig,
if
decay_type
not
in
[
'cosine_with_warmup'
]:
if
decay_type
not
in
[
'cosine_with_warmup'
]:
logging
.
info
(
'Applying %d warmup steps to the learning rate'
,
logging
.
info
(
'Applying %d warmup steps to the learning rate'
,
warmup_steps
)
warmup_steps
)
lr
=
learning_rate
.
WarmupDecaySchedule
(
lr
,
lr
=
learning_rate
.
WarmupDecaySchedule
(
warmup_steps
,
lr
,
warmup_steps
,
warmup_lr
=
base_lr
)
warmup_lr
=
base_lr
)
return
lr
return
lr
official/vision/image_classification/optimizer_factory_test.py
View file @
88253ce5
...
@@ -35,10 +35,8 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -35,10 +35,8 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
return
model
return
model
@
parameterized
.
named_parameters
(
@
parameterized
.
named_parameters
(
(
'sgd'
,
'sgd'
,
0.
,
False
),
(
'sgd'
,
'sgd'
,
0.
,
False
),
(
'momentum'
,
'momentum'
,
0.
,
False
),
(
'momentum'
,
'momentum'
,
0.
,
False
),
(
'rmsprop'
,
'rmsprop'
,
0.
,
False
),
(
'adam'
,
'adam'
,
0.
,
False
),
(
'rmsprop'
,
'rmsprop'
,
0.
,
False
),
(
'adam'
,
'adam'
,
0.
,
False
),
(
'adamw'
,
'adamw'
,
0.
,
False
),
(
'adamw'
,
'adamw'
,
0.
,
False
),
(
'momentum_lookahead'
,
'momentum'
,
0.
,
True
),
(
'momentum_lookahead'
,
'momentum'
,
0.
,
True
),
(
'sgd_ema'
,
'sgd'
,
0.999
,
False
),
(
'sgd_ema'
,
'sgd'
,
0.999
,
False
),
...
@@ -84,16 +82,13 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -84,16 +82,13 @@ class OptimizerFactoryTest(tf.test.TestCase, parameterized.TestCase):
train_steps
=
1
train_steps
=
1
lr
=
optimizer_factory
.
build_learning_rate
(
lr
=
optimizer_factory
.
build_learning_rate
(
params
=
params
,
params
=
params
,
batch_size
=
batch_size
,
train_steps
=
train_steps
)
batch_size
=
batch_size
,
train_steps
=
train_steps
)
self
.
assertTrue
(
self
.
assertTrue
(
issubclass
(
issubclass
(
type
(
lr
),
tf
.
keras
.
optimizers
.
schedules
.
LearningRateSchedule
))
type
(
lr
),
tf
.
keras
.
optimizers
.
schedules
.
LearningRateSchedule
))
@
parameterized
.
named_parameters
(
@
parameterized
.
named_parameters
((
'exponential'
,
'exponential'
),
(
'exponential'
,
'exponential'
),
(
'cosine_with_warmup'
,
'cosine_with_warmup'
))
(
'cosine_with_warmup'
,
'cosine_with_warmup'
))
def
test_learning_rate_with_decay_and_warmup
(
self
,
lr_decay_type
):
def
test_learning_rate_with_decay_and_warmup
(
self
,
lr_decay_type
):
"""Basic smoke test for syntax."""
"""Basic smoke test for syntax."""
params
=
base_configs
.
LearningRateConfig
(
params
=
base_configs
.
LearningRateConfig
(
...
...
Prev
1
…
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment