Commit b0ccdb11 authored by Shixin Luo

resolve conflict with master

parents e61588cd 1611a8c5
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for nms.py."""
# Import libraries
import numpy as np
import tensorflow as tf
from official.vision.beta.ops import nms
class SortedNonMaxSuppressionTest(tf.test.TestCase):
def setUp(self):
super(SortedNonMaxSuppressionTest, self).setUp()
self.boxes_data = [[[0, 0, 1, 1], [0, 0.2, 1, 1.2], [0, 0.4, 1, 1.4],
[0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]],
[[0, 2, 1, 2], [0, 0.8, 1, 1.8], [0, 0.6, 1, 1.6],
[0, 0.4, 1, 1.4], [0, 0.2, 1, 1.2], [0, 0, 1, 1]]]
self.scores_data = [[0.9, 0.7, 0.6, 0.5, 0.4, 0.3],
[0.8, 0.7, 0.6, 0.5, 0.4, 0.3]]
self.max_output_size = 6
self.iou_threshold = 0.5
def testSortedNonMaxSuppressionOnTPU(self):
boxes_np = np.array(self.boxes_data, dtype=np.float32)
scores_np = np.array(self.scores_data, dtype=np.float32)
iou_threshold_np = np.array(self.iou_threshold, dtype=np.float32)
boxes = tf.constant(boxes_np)
scores = tf.constant(scores_np)
iou_threshold = tf.constant(iou_threshold_np)
# Runs on TPU.
strategy = tf.distribute.experimental.TPUStrategy()
with strategy.scope():
scores_tpu, boxes_tpu = nms.sorted_non_max_suppression_padded(
boxes=boxes,
scores=scores,
max_output_size=self.max_output_size,
iou_threshold=iou_threshold)
self.assertEqual(boxes_tpu.numpy().shape, (2, self.max_output_size, 4))
self.assertAllClose(scores_tpu.numpy(),
[[0.9, 0.6, 0.4, 0.3, 0., 0.],
[0.8, 0.7, 0.5, 0.3, 0., 0.]])
def testSortedNonMaxSuppressionOnCPU(self):
boxes_np = np.array(self.boxes_data, dtype=np.float32)
scores_np = np.array(self.scores_data, dtype=np.float32)
iou_threshold_np = np.array(self.iou_threshold, dtype=np.float32)
boxes = tf.constant(boxes_np)
scores = tf.constant(scores_np)
iou_threshold = tf.constant(iou_threshold_np)
# Runs on CPU.
scores_cpu, boxes_cpu = nms.sorted_non_max_suppression_padded(
boxes=boxes,
scores=scores,
max_output_size=self.max_output_size,
iou_threshold=iou_threshold)
self.assertEqual(boxes_cpu.numpy().shape, (2, self.max_output_size, 4))
self.assertAllClose(scores_cpu.numpy(),
[[0.9, 0.6, 0.4, 0.3, 0., 0.],
[0.8, 0.7, 0.5, 0.3, 0., 0.]])
def testSortedNonMaxSuppressionOnTPUSpeed(self):
boxes_np = np.random.rand(2, 12000, 4).astype(np.float32)
scores_np = np.random.rand(2, 12000).astype(np.float32)
iou_threshold_np = np.array(0.7, dtype=np.float32)
boxes = tf.constant(boxes_np)
scores = tf.constant(scores_np)
iou_threshold = tf.constant(iou_threshold_np)
# Runs on TPU.
strategy = tf.distribute.experimental.TPUStrategy()
with strategy.scope():
scores_tpu, boxes_tpu = nms.sorted_non_max_suppression_padded(
boxes=boxes,
scores=scores,
max_output_size=2000,
iou_threshold=iou_threshold)
self.assertEqual(scores_tpu.numpy().shape, (2, 2000))
self.assertEqual(boxes_tpu.numpy().shape, (2, 2000, 4))
if __name__ == '__main__':
tf.test.main()
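A quick back-of-the-envelope check (illustrative only, not part of the commit) of where the expected scores in the tests above come from: greedy NMS at iou_threshold=0.5 over the first batch row keeps boxes 0, 2, 4 and 5 (scores 0.9, 0.6, 0.4, 0.3), and the remaining slots are zero-padded up to max_output_size. A minimal NumPy sketch of that greedy loop, independent of nms.py:

import numpy as np

def iou(a, b):
  # Boxes are [ymin, xmin, ymax, xmax].
  inter_h = max(0., min(a[2], b[2]) - max(a[0], b[0]))
  inter_w = max(0., min(a[3], b[3]) - max(a[1], b[1]))
  inter = inter_h * inter_w
  area = lambda box: (box[2] - box[0]) * (box[3] - box[1])
  union = area(a) + area(b) - inter
  return inter / union if union > 0 else 0.

boxes = [[0, 0, 1, 1], [0, 0.2, 1, 1.2], [0, 0.4, 1, 1.4],
         [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]]
keep = []
for i, box in enumerate(boxes):  # Boxes are already sorted by score.
  if all(iou(box, boxes[j]) <= 0.5 for j in keep):
    keep.append(i)
print(keep)  # [0, 2, 4, 5] -> scores 0.9, 0.6, 0.4, 0.3, then zero padding.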
@@ -159,12 +159,12 @@ def multilevel_crop_and_resize(features,
   with tf.name_scope('multilevel_crop_and_resize'):
     levels = list(features.keys())
-    min_level = min(levels)
-    max_level = max(levels)
+    min_level = int(min(levels))
+    max_level = int(max(levels))
     batch_size, max_feature_height, max_feature_width, num_filters = (
-        features[min_level].get_shape().as_list())
+        features[str(min_level)].get_shape().as_list())
     if batch_size is None:
-      batch_size = tf.shape(features[min_level])[0]
+      batch_size = tf.shape(features[str(min_level)])[0]
     _, num_boxes, _ = boxes.get_shape().as_list()

     # Stack feature pyramid into a features_all of shape
@@ -173,13 +173,13 @@ def multilevel_crop_and_resize(features,
     feature_heights = []
     feature_widths = []
     for level in range(min_level, max_level + 1):
-      shape = features[level].get_shape().as_list()
+      shape = features[str(level)].get_shape().as_list()
       feature_heights.append(shape[1])
       feature_widths.append(shape[2])
       # Concat tensor of [batch_size, height_l * width_l, num_filters] for each
       # levels.
       features_all.append(
-          tf.reshape(features[level], [batch_size, -1, num_filters]))
+          tf.reshape(features[str(level)], [batch_size, -1, num_filters]))
     features_r2 = tf.reshape(tf.concat(features_all, 1), [-1, num_filters])

     # Calculate height_l * width_l for each level.
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for spatial_transform_ops.py."""
# Import libraries
from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from official.vision.beta.ops import spatial_transform_ops
class MultiLevelCropAndResizeTest(tf.test.TestCase):
def test_multilevel_crop_and_resize_square(self):
"""Example test case.
Input =
[
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15],
]
output_size = 2x2
box =
[
[[0, 0, 2, 2]]
]
Gathered data =
[
[0, 1, 1, 2],
[4, 5, 5, 6],
[4, 5, 5, 6],
[8, 9, 9, 10],
]
Interpolation kernel =
[
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
[1, 1, 1, 1],
]
Output =
[
[2.5, 3.5],
[6.5, 7.5]
]
"""
input_size = 4
min_level = 0
max_level = 0
batch_size = 1
output_size = 2
num_filters = 1
features = {}
for level in range(min_level, max_level + 1):
feat_size = int(input_size / 2**level)
features[level] = tf.range(
batch_size * feat_size * feat_size * num_filters, dtype=tf.float32)
features[level] = tf.reshape(
features[level], [batch_size, feat_size, feat_size, num_filters])
boxes = tf.constant([
[[0, 0, 2, 2]],
], dtype=tf.float32)
tf_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
features, boxes, output_size)
roi_features = tf_roi_features.numpy()
self.assertAllClose(
roi_features,
np.array([[2.5, 3.5],
[6.5,
7.5]]).reshape([batch_size, 1, output_size, output_size, 1]))
def test_multilevel_crop_and_resize_rectangle(self):
"""Example test case.
Input =
[
[0, 1, 2, 3],
[4, 5, 6, 7],
[8, 9, 10, 11],
[12, 13, 14, 15],
]
output_size = 2x2
box =
[
[[0, 0, 2, 3]]
]
Box vertices =
[
[[0.5, 0.75], [0.5, 2.25]],
[[1.5, 0.75], [1.5, 2.25]],
]
Gathered data =
[
[0, 1, 2, 3],
[4, 5, 6, 7],
[4, 5, 6, 7],
[8, 9, 10, 11],
]
Interpolation kernel =
[
[0.5 1.5 1.5 0.5],
[0.5 1.5 1.5 0.5],
[0.5 1.5 1.5 0.5],
[0.5 1.5 1.5 0.5],
]
Output =
[
[2.75, 4.25],
[6.75, 8.25]
]
"""
input_size = 4
min_level = 0
max_level = 0
batch_size = 1
output_size = 2
num_filters = 1
features = {}
for level in range(min_level, max_level + 1):
feat_size = int(input_size / 2**level)
features[level] = tf.range(
batch_size * feat_size * feat_size * num_filters, dtype=tf.float32)
features[level] = tf.reshape(
features[level], [batch_size, feat_size, feat_size, num_filters])
boxes = tf.constant([
[[0, 0, 2, 3]],
], dtype=tf.float32)
tf_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
features, boxes, output_size)
roi_features = tf_roi_features.numpy()
self.assertAllClose(
roi_features,
np.array([[2.75, 4.25],
[6.75,
8.25]]).reshape([batch_size, 1, output_size, output_size,
1]))
def test_multilevel_crop_and_resize_two_boxes(self):
"""Test two boxes."""
input_size = 4
min_level = 0
max_level = 0
batch_size = 1
output_size = 2
num_filters = 1
features = {}
for level in range(min_level, max_level + 1):
feat_size = int(input_size / 2**level)
features[level] = tf.range(
batch_size * feat_size * feat_size * num_filters, dtype=tf.float32)
features[level] = tf.reshape(
features[level], [batch_size, feat_size, feat_size, num_filters])
boxes = tf.constant([
[[0, 0, 2, 2], [0, 0, 2, 3]],
], dtype=tf.float32)
tf_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
features, boxes, output_size)
roi_features = tf_roi_features.numpy()
self.assertAllClose(
roi_features,
np.array([[[2.5, 3.5], [6.5, 7.5]], [[2.75, 4.25], [6.75, 8.25]]
]).reshape([batch_size, 2, output_size, output_size, 1]))
def test_multilevel_crop_and_resize_feature_level_assignment(self):
"""Test feature level assignment."""
input_size = 640
min_level = 2
max_level = 5
batch_size = 1
output_size = 2
num_filters = 1
features = {}
for level in range(min_level, max_level + 1):
feat_size = int(input_size / 2**level)
features[level] = float(level) * tf.ones(
[batch_size, feat_size, feat_size, num_filters], dtype=tf.float32)
boxes = tf.constant(
[
[
[0, 0, 111, 111], # Level 2.
[0, 0, 113, 113], # Level 3.
[0, 0, 223, 223], # Level 3.
[0, 0, 225, 225], # Level 4.
[0, 0, 449, 449]
], # Level 5.
],
dtype=tf.float32)
tf_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
features, boxes, output_size)
roi_features = tf_roi_features.numpy()
self.assertAllClose(roi_features[0, 0], 2 * np.ones((2, 2, 1)))
self.assertAllClose(roi_features[0, 1], 3 * np.ones((2, 2, 1)))
self.assertAllClose(roi_features[0, 2], 3 * np.ones((2, 2, 1)))
self.assertAllClose(roi_features[0, 3], 4 * np.ones((2, 2, 1)))
self.assertAllClose(roi_features[0, 4], 5 * np.ones((2, 2, 1)))
def test_multilevel_crop_and_resize_large_input(self):
"""Test 512 boxes on TPU."""
input_size = 1408
min_level = 2
max_level = 6
batch_size = 2
num_boxes = 512
num_filters = 256
output_size = 7
strategy = tf.distribute.experimental.TPUStrategy()
with strategy.scope():
features = {}
for level in range(min_level, max_level + 1):
feat_size = int(input_size / 2**level)
features[level] = tf.constant(
np.reshape(
np.arange(
batch_size * feat_size * feat_size * num_filters,
dtype=np.float32),
[batch_size, feat_size, feat_size, num_filters]),
dtype=tf.bfloat16)
boxes = np.array([
[[0, 0, 256, 256]]*num_boxes,
], dtype=np.float32)
boxes = np.tile(boxes, [batch_size, 1, 1])
tf_boxes = tf.constant(boxes, dtype=tf.float32)
tf_roi_features = spatial_transform_ops.multilevel_crop_and_resize(
features, tf_boxes)
roi_features = tf_roi_features.numpy()
self.assertEqual(
roi_features.shape,
(batch_size, num_boxes, output_size, output_size, num_filters))
class CropMaskInTargetBoxTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
(False),
(True),
)
def test_crop_mask_in_target_box(self, use_einsum):
batch_size = 1
num_masks = 2
height = 2
width = 2
output_size = 2
strategy = tf.distribute.experimental.TPUStrategy()
with strategy.scope():
masks = tf.ones([batch_size, num_masks, height, width])
boxes = tf.constant(
[[0., 0., 1., 1.],
[0., 0., 1., 1.]])
target_boxes = tf.constant(
[[0., 0., 1., 1.],
[-1., -1., 1., 1.]])
expected_outputs = np.array([
[[[1., 1.],
[1., 1.]],
[[0., 0.],
[0., 1.]]]])
boxes = tf.reshape(boxes, [batch_size, num_masks, 4])
target_boxes = tf.reshape(target_boxes, [batch_size, num_masks, 4])
tf_cropped_masks = spatial_transform_ops.crop_mask_in_target_box(
masks, boxes, target_boxes, output_size, use_einsum=use_einsum)
cropped_masks = tf_cropped_masks.numpy()
self.assertAllEqual(cropped_masks, expected_outputs)
if __name__ == '__main__':
tf.test.main()
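A small NumPy sketch (illustrative only, not part of the commit) of the arithmetic in the square-box docstring above: with an all-ones interpolation kernel, every output cell of the 2x2 ROI is the mean of a 2x2 window of the gathered data, which reproduces the documented output.

import numpy as np

gathered = np.array([[0, 1, 1, 2],
                     [4, 5, 5, 6],
                     [4, 5, 5, 6],
                     [8, 9, 9, 10]], dtype=np.float32)
out = np.zeros((2, 2), dtype=np.float32)
for i in range(2):
  for j in range(2):
    out[i, j] = gathered[2 * i:2 * i + 2, 2 * j:2 * j + 2].mean()
print(out)  # [[2.5, 3.5], [6.5, 7.5]], matching the docstring.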
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Base class for model export."""
import abc
import tensorflow as tf
def _decode_image(encoded_image_bytes):
image_tensor = tf.image.decode_image(encoded_image_bytes, channels=3)
image_tensor.set_shape((None, None, 3))
return image_tensor
def _decode_tf_example(tf_example_string_tensor):
keys_to_features = {'image/encoded': tf.io.FixedLenFeature((), tf.string)}
parsed_tensors = tf.io.parse_single_example(
serialized=tf_example_string_tensor, features=keys_to_features)
image_tensor = _decode_image(parsed_tensors['image/encoded'])
return image_tensor
class ExportModule(tf.Module, metaclass=abc.ABCMeta):
"""Base Export Module."""
def __init__(self, params, batch_size, input_image_size, model=None):
"""Initializes a module for export.
Args:
params: Experiment params.
batch_size: Int or None.
input_image_size: List or Tuple of height, width of the input image.
model: A tf.keras.Model instance to be exported.
"""
super(ExportModule, self).__init__()
self._params = params
self._batch_size = batch_size
self._input_image_size = input_image_size
self._model = model
@abc.abstractmethod
def build_model(self):
"""Builds model and sets self._model."""
@abc.abstractmethod
def _run_inference_on_image_tensors(self, images):
"""Runs inference on images."""
@tf.function
def inference_from_image_tensors(self, input_tensor):
return dict(outputs=self._run_inference_on_image_tensors(input_tensor))
@tf.function
def inference_from_image_bytes(self, input_tensor):
with tf.device('cpu:0'):
images = tf.nest.map_structure(
tf.identity,
tf.map_fn(
_decode_image,
elems=input_tensor,
fn_output_signature=tf.TensorSpec(
shape=self._input_image_size + [3], dtype=tf.uint8),
parallel_iterations=32))
images = tf.stack(images)
return dict(outputs=self._run_inference_on_image_tensors(images))
@tf.function
def inference_from_tf_example(self, input_tensor):
with tf.device('cpu:0'):
images = tf.nest.map_structure(
tf.identity,
tf.map_fn(
_decode_tf_example,
elems=input_tensor,
fn_output_signature=tf.TensorSpec(
shape=self._input_image_size + [3], dtype=tf.uint8),
parallel_iterations=32))
images = tf.stack(images)
return dict(outputs=self._run_inference_on_image_tensors(images))
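For reference, a minimal sketch of how ExportModule is meant to be subclassed; the _IdentityExportModule name and its passthrough model are hypothetical, for illustration only, and not part of this change:

class _IdentityExportModule(ExportModule):
  """Hypothetical subclass that echoes the input batch back."""

  def build_model(self):
    # A trivial stand-in model; real modules build the trained network here.
    inputs = tf.keras.Input(shape=self._input_image_size + [3])
    self._model = tf.keras.Model(inputs, inputs)
    return self._model

  def _run_inference_on_image_tensors(self, images):
    # Real modules would preprocess the images and post-process the outputs.
    return self._model(tf.cast(images, tf.float32), training=False)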
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Vision models export binary for serving/inference.
To export a trained checkpoint in saved_model format (shell script):
EXPERIMENT_TYPE=XX
CHECKPOINT_PATH=XX
EXPORT_DIR_PATH=XX
export_saved_model --experiment=${EXPERIMENT_TYPE} \
--export_dir=${EXPORT_DIR_PATH}/ \
--checkpoint_path=${CHECKPOINT_PATH} \
--batch_size=2 \
--input_image_size=224,224
To serve (python):
export_dir_path = XX
input_type = XX
input_images = XX
imported = tf.saved_model.load(export_dir_path)
model_fn = imported.signatures['serving_default']
output = model_fn(input_images)
"""
import os
from absl import app
from absl import flags
import tensorflow.compat.v2 as tf
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.core import train_utils
from official.modeling import hyperparams
from official.vision.beta import configs
from official.vision.beta.serving import image_classification
FLAGS = flags.FLAGS
flags.DEFINE_string(
'experiment', None, 'experiment type, e.g. retinanet_resnetfpn_coco')
flags.DEFINE_string('export_dir', None, 'The export directory.')
flags.DEFINE_string('checkpoint_path', None, 'Checkpoint path.')
flags.DEFINE_multi_string(
'config_file',
default=None,
help='YAML/JSON files which specify overrides. The override order '
'follows the order of args. Note that each file '
'can be used as an override template to override the default parameters '
'specified in Python. If the same parameter is specified in both '
'`--config_file` and `--params_override`, `config_file` will be used '
'first, followed by params_override.')
flags.DEFINE_string(
'params_override', '',
'The JSON/YAML file or string which specifies the parameter to be overridden'
' on top of `config_file` template.')
flags.DEFINE_integer(
'batch_size', None, 'The batch size.')
flags.DEFINE_string(
'input_type', 'image_tensor',
'One of `image_tensor`, `image_bytes`, `tf_example`.')
flags.DEFINE_string(
'input_image_size', '224,224',
'The comma-separated string of two integers representing the height and width '
'of the input to the model.')
def export_inference_graph(input_type, batch_size, input_image_size, params,
checkpoint_path, export_dir):
"""Exports inference graph for the model specified in the exp config.
Saved model is stored at export_dir/saved_model, checkpoint is saved
at export_dir/checkpoint, and params is saved at export_dir/params.yaml.
Args:
input_type: One of `image_tensor`, `image_bytes`, `tf_example`.
batch_size: Int, or None.
input_image_size: List or Tuple of height and width.
params: Experiment params.
checkpoint_path: Trained checkpoint path or directory.
export_dir: Export directory path.
"""
output_checkpoint_directory = os.path.join(export_dir, 'checkpoint')
output_saved_model_directory = os.path.join(export_dir, 'saved_model')
if isinstance(params.task,
configs.image_classification.ImageClassificationTask):
export_module = image_classification.ClassificationModule(
params=params,
batch_size=batch_size,
input_image_size=input_image_size)
else:
raise ValueError('Export module not implemented for {} task.'.format(
type(params.task)))
model = export_module.build_model()
ckpt = tf.train.Checkpoint(model=model)
ckpt_dir_or_file = checkpoint_path
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
status = ckpt.restore(ckpt_dir_or_file).expect_partial()
if input_type == 'image_tensor':
input_signature = tf.TensorSpec(
shape=[batch_size, input_image_size[0], input_image_size[1], 3],
dtype=tf.uint8)
signatures = {
'serving_default':
export_module.inference_from_image_tensors.get_concrete_function(
input_signature)
}
elif input_type == 'image_bytes':
input_signature = tf.TensorSpec(shape=[batch_size], dtype=tf.string)
signatures = {
'serving_default':
export_module.inference_from_image_bytes.get_concrete_function(
input_signature)
}
elif input_type == 'tf_example':
input_signature = tf.TensorSpec(shape=[batch_size], dtype=tf.string)
signatures = {
'serving_default':
export_module.inference_from_tf_example.get_concrete_function(
input_signature)
}
else:
raise ValueError('Unrecognized `input_type`')
status.assert_existing_objects_matched()
ckpt.save(os.path.join(output_checkpoint_directory, 'ckpt'))
tf.saved_model.save(export_module,
output_saved_model_directory,
signatures=signatures)
train_utils.serialize_config(params, export_dir)
def main(_):
params = exp_factory.get_exp_config(FLAGS.experiment)
for config_file in FLAGS.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
if FLAGS.params_override:
params = hyperparams.override_params_dict(
params, FLAGS.params_override, is_strict=True)
params.validate()
params.lock()
export_inference_graph(
input_type=FLAGS.input_type,
batch_size=FLAGS.batch_size,
input_image_size=[int(x) for x in FLAGS.input_image_size.split(',')],
params=params,
checkpoint_path=FLAGS.checkpoint_path,
export_dir=FLAGS.export_dir)
if __name__ == '__main__':
app.run(main)
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Detection input and model functions for serving/inference."""
import tensorflow as tf
from official.vision.beta.modeling import factory
from official.vision.beta.ops import preprocess_ops
from official.vision.beta.serving import export_base
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)
class ClassificationModule(export_base.ExportModule):
"""classification Module."""
def build_model(self):
input_specs = tf.keras.layers.InputSpec(
shape=[self._batch_size] + self._input_image_size + [3])
self._model = factory.build_classification_model(
input_specs=input_specs,
model_config=self._params.task.model,
l2_regularizer=None)
return self._model
def _build_inputs(self, image):
"""Builds classification model inputs for serving."""
# Center crops and resizes image.
image = preprocess_ops.center_crop_image(image)
image = tf.image.resize(
image, self._input_image_size, method=tf.image.ResizeMethod.BILINEAR)
image = tf.reshape(
image, [self._input_image_size[0], self._input_image_size[1], 3])
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image,
offset=MEAN_RGB,
scale=STDDEV_RGB)
return image
def _run_inference_on_image_tensors(self, images):
"""Cast image to float and run inference.
Args:
images: uint8 Tensor of shape [batch_size, None, None, 3]
Returns:
Tensor holding classification output logits.
"""
with tf.device('cpu:0'):
images = tf.cast(images, dtype=tf.float32)
images = tf.nest.map_structure(
tf.identity,
tf.map_fn(
self._build_inputs, elems=images,
fn_output_signature=tf.TensorSpec(
shape=self._input_image_size + [3], dtype=tf.float32),
parallel_iterations=32
)
)
logits = self._model(images, training=False)
return logits
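To make the MEAN_RGB/STDDEV_RGB constants above concrete, here is a NumPy sketch of the normalization arithmetic, under the assumption that offset and scale are applied to pixel values in the [0, 255] range (not part of the commit):

import numpy as np

mean_rgb = np.array([0.485 * 255, 0.456 * 255, 0.406 * 255], np.float32)
stddev_rgb = np.array([0.229 * 255, 0.224 * 255, 0.225 * 255], np.float32)
pixel = np.array([128., 128., 128.], np.float32)  # A mid-gray pixel.
print((pixel - mean_rgb) / stddev_rgb)  # Approximately [0.074, 0.205, 0.427].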
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test for image classification export lib."""
import io
import os
from absl.testing import parameterized
import numpy as np
from PIL import Image
import tensorflow as tf
from official.common import registry_imports # pylint: disable=unused-import
from official.core import exp_factory
from official.vision.beta.serving import image_classification
class ImageClassificationExportTest(tf.test.TestCase, parameterized.TestCase):
def _get_classification_module(self):
params = exp_factory.get_exp_config('resnet_imagenet')
params.task.model.backbone.resnet.model_id = 18
classification_module = image_classification.ClassificationModule(
params, batch_size=1, input_image_size=[224, 224])
return classification_module
def _export_from_module(self, module, input_type, save_directory):
if input_type == 'image_tensor':
input_signature = tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.uint8)
signatures = {
'serving_default':
module.inference_from_image_tensors.get_concrete_function(
input_signature)
}
elif input_type == 'image_bytes':
input_signature = tf.TensorSpec(shape=[None], dtype=tf.string)
signatures = {
'serving_default':
module.inference_from_image_bytes.get_concrete_function(
input_signature)
}
elif input_type == 'tf_example':
input_signature = tf.TensorSpec(shape=[None], dtype=tf.string)
signatures = {
'serving_default':
module.inference_from_tf_example.get_concrete_function(
input_signature)
}
else:
raise ValueError('Unrecognized `input_type`')
tf.saved_model.save(module,
save_directory,
signatures=signatures)
def _get_dummy_input(self, input_type):
"""Get dummy input for the given input type."""
if input_type == 'image_tensor':
return tf.zeros((1, 224, 224, 3), dtype=np.uint8)
elif input_type == 'image_bytes':
image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
byte_io = io.BytesIO()
image.save(byte_io, 'PNG')
return [byte_io.getvalue()]
elif input_type == 'tf_example':
image_tensor = tf.zeros((224, 224, 3), dtype=tf.uint8)
encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).numpy()
example = tf.train.Example(
features=tf.train.Features(
feature={
'image/encoded':
tf.train.Feature(
bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
})).SerializeToString()
return [example]
@parameterized.parameters(
{'input_type': 'image_tensor'},
{'input_type': 'image_bytes'},
{'input_type': 'tf_example'},
)
def test_export(self, input_type='image_tensor'):
tmp_dir = self.get_temp_dir()
module = self._get_classification_module()
model = module.build_model()
self._export_from_module(module, input_type, tmp_dir)
self.assertTrue(os.path.exists(os.path.join(tmp_dir, 'saved_model.pb')))
self.assertTrue(os.path.exists(
os.path.join(tmp_dir, 'variables', 'variables.index')))
self.assertTrue(os.path.exists(
os.path.join(tmp_dir, 'variables', 'variables.data-00000-of-00001')))
imported = tf.saved_model.load(tmp_dir)
classification_fn = imported.signatures['serving_default']
images = self._get_dummy_input(input_type)
processed_images = tf.nest.map_structure(
tf.stop_gradient,
tf.map_fn(
module._build_inputs,
elems=tf.zeros((1, 224, 224, 3), dtype=tf.uint8),
fn_output_signature=tf.TensorSpec(
shape=[224, 224, 3], dtype=tf.float32)))
expected_output = model(processed_images, training=False)
out = classification_fn(tf.constant(images))
self.assertAllClose(out['outputs'].numpy(), expected_output.numpy())
if __name__ == '__main__':
tf.test.main()
@@ -18,3 +18,4 @@
 from official.vision.beta.tasks import image_classification
 from official.vision.beta.tasks import maskrcnn
 from official.vision.beta.tasks import retinanet
+from official.vision.beta.tasks import video_classification
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for MaskRCNN task."""
# pylint: disable=unused-import
from absl.testing import parameterized
import orbit
import tensorflow as tf
from official.core import exp_factory
from official.modeling import optimization
from official.vision import beta
from official.vision.beta.tasks import maskrcnn
class MaskRCNNTaskTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
("fasterrcnn_resnetfpn_coco", True),
("fasterrcnn_resnetfpn_coco", False),
("maskrcnn_resnetfpn_coco", True),
("maskrcnn_resnetfpn_coco", False),
)
def test_maskrcnn_task_train(self, test_config, is_training):
"""MaskRCNN task test for training and validation using toy configs."""
config = exp_factory.get_exp_config(test_config)
tf.keras.mixed_precision.experimental.set_policy("mixed_bfloat16")
# modify config to suit local testing
config.trainer.steps_per_loop = 1
config.task.train_data.global_batch_size = 2
config.task.model.input_size = [384, 384, 3]
config.train_steps = 2
config.task.train_data.shuffle_buffer_size = 10
config.task.train_data.input_path = "coco/train-00000-of-00256.tfrecord"
config.task.validation_data.global_batch_size = 2
config.task.validation_data.input_path = "coco/val-00000-of-00032.tfrecord"
task = maskrcnn.MaskRCNNTask(config.task)
model = task.build_model()
metrics = task.build_metrics(training=is_training)
strategy = tf.distribute.get_strategy()
data_config = config.task.train_data if is_training else config.task.validation_data
dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
data_config)
iterator = iter(dataset)
opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
if is_training:
task.train_step(next(iterator), model, optimizer, metrics=metrics)
else:
task.validation_step(next(iterator), model, metrics=metrics)
if __name__ == "__main__":
tf.test.main()
@@ -131,7 +131,7 @@ class RetinaNetTask(base_task.Task):
   def build_losses(self, outputs, labels, aux_losses=None):
     """Build RetinaNet losses."""
     params = self.task_config
-    cls_loss_fn = keras_cv.FocalLoss(
+    cls_loss_fn = keras_cv.losses.FocalLoss(
         alpha=params.losses.focal_loss_alpha,
         gamma=params.losses.focal_loss_gamma,
         reduction=tf.keras.losses.Reduction.SUM)
@@ -145,14 +145,14 @@ class RetinaNetTask(base_task.Task):
     num_positives = tf.reduce_sum(box_sample_weight) + 1.0
     cls_sample_weight = cls_sample_weight / num_positives
     box_sample_weight = box_sample_weight / num_positives
-    y_true_cls = keras_cv.multi_level_flatten(
+    y_true_cls = keras_cv.losses.multi_level_flatten(
         labels['cls_targets'], last_dim=None)
     y_true_cls = tf.one_hot(y_true_cls, params.model.num_classes)
-    y_pred_cls = keras_cv.multi_level_flatten(
+    y_pred_cls = keras_cv.losses.multi_level_flatten(
         outputs['cls_outputs'], last_dim=params.model.num_classes)
-    y_true_box = keras_cv.multi_level_flatten(
+    y_true_box = keras_cv.losses.multi_level_flatten(
         labels['box_targets'], last_dim=4)
-    y_pred_box = keras_cv.multi_level_flatten(
+    y_pred_box = keras_cv.losses.multi_level_flatten(
         outputs['box_outputs'], last_dim=4)
     cls_loss = cls_loss_fn(
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for RetinaNet task."""
# pylint: disable=unused-import
from absl.testing import parameterized
import orbit
import tensorflow as tf
from official.core import exp_factory
from official.modeling import optimization
from official.vision import beta
from official.vision.beta.tasks import retinanet
class RetinaNetTaskTest(parameterized.TestCase, tf.test.TestCase):
@parameterized.parameters(
("retinanet_resnetfpn_coco", True),
("retinanet_spinenet_coco", True),
)
def test_retinanet_task_train(self, test_config, is_training):
"""RetinaNet task test for training and val using toy configs."""
config = exp_factory.get_exp_config(test_config)
# modify config to suit local testing
config.trainer.steps_per_loop = 1
config.task.train_data.global_batch_size = 2
config.task.validation_data.global_batch_size = 2
config.task.train_data.shuffle_buffer_size = 4
config.task.validation_data.shuffle_buffer_size = 4
config.train_steps = 2
task = retinanet.RetinaNetTask(config.task)
model = task.build_model()
metrics = task.build_metrics()
strategy = tf.distribute.get_strategy()
data_config = config.task.train_data if is_training else config.task.validation_data
dataset = orbit.utils.make_distributed_dataset(strategy, task.build_inputs,
data_config)
iterator = iter(dataset)
opt_factory = optimization.OptimizerFactory(config.trainer.optimizer_config)
optimizer = opt_factory.build_optimizer(opt_factory.build_learning_rate())
if is_training:
task.train_step(next(iterator), model, optimizer, metrics=metrics)
else:
task.validation_step(next(iterator), model, metrics=metrics)
if __name__ == "__main__":
tf.test.main()
# Lint as: python3
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Video classification task definition."""
import tensorflow as tf
from official.core import base_task
from official.core import input_reader
from official.core import task_factory
from official.modeling import tf_utils
from official.vision.beta.configs import video_classification as exp_cfg
from official.vision.beta.dataloaders import video_input
from official.vision.beta.modeling import factory
@task_factory.register_task_cls(exp_cfg.VideoClassificationTask)
class VideoClassificationTask(base_task.Task):
"""A task for video classification."""
def build_model(self):
"""Builds video classification model."""
input_specs = tf.keras.layers.InputSpec(shape=[None, None, None, None, 3])
l2_weight_decay = self.task_config.losses.l2_weight_decay
# Divide weight decay by 2.0 to match the implementation of tf.nn.l2_loss.
# (https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/l2)
# (https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss)
l2_regularizer = (tf.keras.regularizers.l2(
l2_weight_decay / 2.0) if l2_weight_decay else None)
model = factory.build_video_classification_model(
input_specs=input_specs,
model_config=self.task_config.model,
num_classes=self.task_config.train_data.num_classes,
l2_regularizer=l2_regularizer)
return model
def build_inputs(self, params: exp_cfg.DataConfig, input_context=None):
"""Builds classification input."""
decoder = video_input.Decoder()
decoder_fn = decoder.decode
parser = video_input.Parser(input_params=params)
postprocess_fn = video_input.PostBatchProcessor(params)
reader = input_reader.InputReader(
params,
dataset_fn=tf.data.TFRecordDataset,
decoder_fn=decoder_fn,
parser_fn=parser.parse_fn(params.is_training),
postprocess_fn=postprocess_fn)
dataset = reader.read(input_context=input_context)
return dataset
def build_losses(self, labels, model_outputs, aux_losses=None):
"""Sparse categorical cross entropy loss.
Args:
labels: labels.
model_outputs: Output logits of the classifier.
aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.
Returns:
The total loss tensor.
"""
losses_config = self.task_config.losses
if losses_config.one_hot:
total_loss = tf.keras.losses.categorical_crossentropy(
labels,
model_outputs,
from_logits=True,
label_smoothing=losses_config.label_smoothing)
else:
total_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, model_outputs, from_logits=True)
total_loss = tf_utils.safe_mean(total_loss)
if aux_losses:
total_loss += tf.add_n(aux_losses)
return total_loss
def build_metrics(self, training=True):
"""Gets streaming metrics for training/validation."""
if self.task_config.losses.one_hot:
metrics = [
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='top_1_accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_accuracy')
]
else:
metrics = [
tf.keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=1, name='top_1_accuracy'),
tf.keras.metrics.SparseTopKCategoricalAccuracy(
k=5, name='top_5_accuracy')
]
return metrics
def train_step(self, inputs, model, optimizer, metrics=None):
"""Does forward and backward.
Args:
inputs: a dictionary of input tensors.
model: the model, forward pass definition.
optimizer: the optimizer for this training step.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
with tf.GradientTape() as tape:
outputs = model(features['image'], training=True)
# Casting the output layer to float32 is necessary when mixed precision is
# mixed_float16 or mixed_bfloat16, to ensure the outputs are float32.
outputs = tf.nest.map_structure(
lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss.
loss = self.build_losses(
model_outputs=outputs, labels=labels, aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / num_replicas
# For mixed_precision policy, when LossScaleOptimizer is used, loss is
# scaled for numerical stability.
if isinstance(
optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
scaled_loss = optimizer.get_scaled_loss(scaled_loss)
tvars = model.trainable_variables
grads = tape.gradient(scaled_loss, tvars)
# Scales back gradient before apply_gradients when LossScaleOptimizer is
# used.
if isinstance(
optimizer, tf.keras.mixed_precision.experimental.LossScaleOptimizer):
grads = optimizer.get_unscaled_gradients(grads)
# Apply gradient clipping.
if self.task_config.gradient_clip_norm > 0:
grads, _ = tf.clip_by_global_norm(
grads, self.task_config.gradient_clip_norm)
optimizer.apply_gradients(list(zip(grads, tvars)))
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
elif model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in model.metrics})
return logs
def validation_step(self, inputs, model, metrics=None):
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features, labels = inputs
outputs = self.inference_step(features['image'], model)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
loss = self.build_losses(model_outputs=outputs, labels=labels,
aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
elif model.compiled_metrics:
self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
logs.update({m.name: m.result() for m in model.metrics})
return logs
def inference_step(self, inputs, model):
"""Performs the forward step."""
return model(inputs, training=False)
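A tiny sketch (illustrative, not part of the commit) verifying the weight-decay comment in build_model above: tf.keras.regularizers.l2(l) computes l * sum(w ** 2), while tf.nn.l2_loss(w) computes sum(w ** 2) / 2, so passing l2_weight_decay / 2.0 to the Keras regularizer matches the tf.nn.l2_loss convention.

import tensorflow as tf

w = tf.constant([1.0, 2.0, 3.0])
l2_weight_decay = 1e-4
keras_penalty = tf.keras.regularizers.l2(l2_weight_decay / 2.0)(w)
nn_penalty = l2_weight_decay * tf.nn.l2_loss(w)
# Both equal l2_weight_decay * sum(w ** 2) / 2 = 7e-4.
assert abs(float(keras_penalty) - float(nn_penalty)) < 1e-9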
@@ -19,13 +19,13 @@ from absl import app
 from absl import flags
 import gin

+from official.common import distribute_utils
 from official.common import flags as tfm_flags
 from official.common import registry_imports  # pylint: disable=unused-import
 from official.core import task_factory
 from official.core import train_lib
 from official.core import train_utils
 from official.modeling import performance
-from official.utils.misc import distribution_utils

 FLAGS = flags.FLAGS
@@ -46,7 +46,7 @@ def main(_):
   if params.runtime.mixed_precision_dtype:
     performance.set_mixed_precision_policy(params.runtime.mixed_precision_dtype,
                                            params.runtime.loss_scale)
-  distribution_strategy = distribution_utils.get_distribution_strategy(
+  distribution_strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
       num_gpus=params.runtime.num_gpus,
@@ -21,11 +21,11 @@ from __future__ import print_function
 import collections

 import tensorflow as tf

+from official.vision import keras_cv
 from official.vision.detection.utils.object_detection import argmax_matcher
 from official.vision.detection.utils.object_detection import balanced_positive_negative_sampler
 from official.vision.detection.utils.object_detection import box_list
 from official.vision.detection.utils.object_detection import faster_rcnn_box_coder
-from official.vision.detection.utils.object_detection import region_similarity_calculator
 from official.vision.detection.utils.object_detection import target_assigner
@@ -134,7 +134,7 @@ class AnchorLabeler(object):
         upper-bound threshold to assign negative labels for anchors. An anchor
         with a score below the threshold is labeled negative.
     """
-    similarity_calc = region_similarity_calculator.IouSimilarity()
+    similarity_calc = keras_cv.ops.IouSimilarity()
     matcher = argmax_matcher.ArgMaxMatcher(
         match_threshold,
         unmatched_threshold=unmatched_threshold,
@@ -14,28 +14,19 @@
 # ==============================================================================
 """Main function to train various object detection models."""

-from __future__ import absolute_import
-from __future__ import division
-# from __future__ import google_type_annotations
-from __future__ import print_function
-
 import functools
 import pprint

-# pylint: disable=g-bad-import-order
-# Import libraries
-import tensorflow as tf
 from absl import app
 from absl import flags
 from absl import logging
-# pylint: enable=g-bad-import-order
+import tensorflow as tf

+from official.common import distribute_utils
 from official.modeling.hyperparams import params_dict
 from official.modeling.training import distributed_executor as executor
 from official.utils import hyperparams_flags
 from official.utils.flags import core as flags_core
-from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.vision.detection.configs import factory as config_factory
 from official.vision.detection.dataloader import input_reader
@@ -87,9 +78,9 @@ def run_executor(params,
     strategy = prebuilt_strategy
   else:
     strategy_config = params.strategy_config
-    distribution_utils.configure_cluster(strategy_config.worker_hosts,
-                                         strategy_config.task_index)
-    strategy = distribution_utils.get_distribution_strategy(
+    distribute_utils.configure_cluster(strategy_config.worker_hosts,
+                                       strategy_config.task_index)
+    strategy = distribute_utils.get_distribution_strategy(
         distribution_strategy=params.strategy_type,
         num_gpus=strategy_config.num_gpus,
         all_reduce_alg=strategy_config.all_reduce_alg,
@@ -151,8 +151,8 @@ class TargetAssigner(object):
       groundtruth_weights = tf.ones([num_gt_boxes], dtype=tf.float32)
     with tf.control_dependencies(
         [unmatched_shape_assert, labels_and_box_shapes_assert]):
-      match_quality_matrix = self._similarity_calc.compare(
-          groundtruth_boxes, anchors)
+      match_quality_matrix = self._similarity_calc(
+          groundtruth_boxes.get(), anchors.get())
       match = self._matcher.match(match_quality_matrix, **params)
       reg_targets = self._create_regression_targets(anchors, groundtruth_boxes,
                                                     match)
@@ -29,16 +29,18 @@ from official.modeling import optimization
 from official.utils.misc import keras_utils


-def get_callbacks(model_checkpoint: bool = True,
-                  include_tensorboard: bool = True,
-                  time_history: bool = True,
-                  track_lr: bool = True,
-                  write_model_weights: bool = True,
-                  apply_moving_average: bool = False,
-                  initial_step: int = 0,
-                  batch_size: int = 0,
-                  log_steps: int = 0,
-                  model_dir: str = None) -> List[tf.keras.callbacks.Callback]:
+def get_callbacks(
+    model_checkpoint: bool = True,
+    include_tensorboard: bool = True,
+    time_history: bool = True,
+    track_lr: bool = True,
+    write_model_weights: bool = True,
+    apply_moving_average: bool = False,
+    initial_step: int = 0,
+    batch_size: int = 0,
+    log_steps: int = 0,
+    model_dir: str = None,
+    backup_and_restore: bool = False) -> List[tf.keras.callbacks.Callback]:
   """Get all callbacks."""
   model_dir = model_dir or ''
   callbacks = []
@@ -47,6 +49,10 @@ def get_callbacks(model_checkpoint: bool = True,
     callbacks.append(
         tf.keras.callbacks.ModelCheckpoint(
             ckpt_full_path, save_weights_only=True, verbose=1))
+  if backup_and_restore:
+    backup_dir = os.path.join(model_dir, 'tmp')
+    callbacks.append(
+        tf.keras.callbacks.experimental.BackupAndRestore(backup_dir))
   if include_tensorboard:
     callbacks.append(
         CustomTensorBoard(
@@ -23,11 +23,10 @@ from absl import app
 from absl import flags
 from absl import logging
 import tensorflow as tf

+from official.common import distribute_utils
 from official.modeling import hyperparams
 from official.modeling import performance
 from official.utils import hyperparams_flags
-from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.vision.image_classification import callbacks as custom_callbacks
 from official.vision.image_classification import dataset_factory
@@ -291,17 +290,17 @@ def train_and_eval(
   """Runs the train and eval path using compile/fit."""
   logging.info('Running train and eval.')

-  distribution_utils.configure_cluster(params.runtime.worker_hosts,
-                                       params.runtime.task_index)
+  distribute_utils.configure_cluster(params.runtime.worker_hosts,
+                                     params.runtime.task_index)

   # Note: for TPUs, strategy and scope should be created before the dataset
-  strategy = strategy_override or distribution_utils.get_distribution_strategy(
+  strategy = strategy_override or distribute_utils.get_distribution_strategy(
       distribution_strategy=params.runtime.distribution_strategy,
       all_reduce_alg=params.runtime.all_reduce_alg,
       num_gpus=params.runtime.num_gpus,
       tpu_address=params.runtime.tpu)

-  strategy_scope = distribution_utils.get_strategy_scope(strategy)
+  strategy_scope = distribute_utils.get_strategy_scope(strategy)
   logging.info('Detected %d devices.',
                strategy.num_replicas_in_sync if strategy else 1)
@@ -369,7 +368,8 @@ def train_and_eval(
       initial_step=initial_epoch * train_steps,
       batch_size=train_builder.global_batch_size,
       log_steps=params.train.time_history.log_steps,
-      model_dir=params.model_dir)
+      model_dir=params.model_dir,
+      backup_and_restore=params.train.callbacks.enable_backup_and_restore)
   serialize_config(params=params, model_dir=params.model_dir)
@@ -25,9 +25,8 @@ from absl import flags
 from absl import logging
 import tensorflow as tf
 import tensorflow_datasets as tfds

+from official.common import distribute_utils
 from official.utils.flags import core as flags_core
-from official.utils.misc import distribution_utils
 from official.utils.misc import model_helpers
 from official.vision.image_classification.resnet import common
@@ -82,12 +81,12 @@ def run(flags_obj, datasets_override=None, strategy_override=None):
   Returns:
     Dictionary of training and eval stats.
   """
-  strategy = strategy_override or distribution_utils.get_distribution_strategy(
+  strategy = strategy_override or distribute_utils.get_distribution_strategy(
      distribution_strategy=flags_obj.distribution_strategy,
      num_gpus=flags_obj.num_gpus,
      tpu_address=flags_obj.tpu)

-  strategy_scope = distribution_utils.get_strategy_scope(strategy)
+  strategy_scope = distribute_utils.get_strategy_scope(strategy)

   mnist = tfds.builder('mnist', data_dir=flags_obj.data_dir)
   if flags_obj.download:
@@ -23,10 +23,9 @@ from absl import flags
 from absl import logging
 import orbit
 import tensorflow as tf

+from official.common import distribute_utils
 from official.modeling import performance
 from official.utils.flags import core as flags_core
-from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.utils.misc import model_helpers
 from official.vision.image_classification.resnet import common
@@ -117,7 +116,7 @@ def run(flags_obj):
                  else 'channels_last')
   tf.keras.backend.set_image_data_format(data_format)

-  strategy = distribution_utils.get_distribution_strategy(
+  strategy = distribute_utils.get_distribution_strategy(
       distribution_strategy=flags_obj.distribution_strategy,
       num_gpus=flags_obj.num_gpus,
       all_reduce_alg=flags_obj.all_reduce_alg,
@@ -144,7 +143,7 @@ def run(flags_obj):
       flags_obj.batch_size,
       flags_obj.log_steps,
       logdir=flags_obj.model_dir if flags_obj.enable_tensorboard else None)

-  with distribution_utils.get_strategy_scope(strategy):
+  with distribute_utils.get_strategy_scope(strategy):
     runnable = resnet_runnable.ResnetRunnable(flags_obj, time_callback,
                                               per_epoch_steps)