Commit 833e6939 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 278969067
parent bf0dc049
......@@ -64,6 +64,11 @@ def run_executor(params,
callbacks=None):
"""Runs Retinanet model on distribution strategy defined by the user."""
if params.architecture.use_bfloat16:
policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
'mixed_bfloat16')
tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)
model_builder = model_factory.model_generator(params)
if FLAGS.mode == 'train':
......
......@@ -85,7 +85,6 @@ class Model(object):
def __init__(self, params):
self._use_bfloat16 = params.architecture.use_bfloat16
assert not self._use_bfloat16, 'bfloat16 is not supported in Keras yet.'
# Optimization.
self._optimizer_fn = OptimizerFactory(params.train.optimizer)
......
......@@ -318,7 +318,8 @@ class GenerateOneStageDetections(tf.keras.layers.Layer):
boxes = tf.expand_dims(boxes, axis=2)
(nmsed_boxes, nmsed_scores, nmsed_classes,
valid_detections) = self._generate_detections(boxes, scores)
valid_detections) = self._generate_detections(
tf.cast(boxes, tf.float32), tf.cast(scores, tf.float32))
# Adds 1 to offset the background class which has index 0.
nmsed_classes += 1
return nmsed_boxes, nmsed_scores, nmsed_classes, valid_detections
......@@ -92,7 +92,9 @@ class RetinanetModel(base_model.Model):
input_shape = (
params.retinanet_parser.output_size +
[params.retinanet_parser.num_channels])
self._input_layer = tf.keras.layers.Input(shape=input_shape, name='')
self._input_layer = tf.keras.layers.Input(
shape=input_shape, name='',
dtype=tf.bfloat16 if self._use_bfloat16 else tf.float32)
def build_outputs(self, inputs, mode):
backbone_features = self._backbone_fn(
......@@ -101,6 +103,13 @@ class RetinanetModel(base_model.Model):
backbone_features, is_training=(mode == mode_keys.TRAIN))
cls_outputs, box_outputs = self._head_fn(
fpn_features, is_training=(mode == mode_keys.TRAIN))
if self._use_bfloat16:
levels = cls_outputs.keys()
for level in levels:
cls_outputs[level] = tf.cast(cls_outputs[level], tf.float32)
box_outputs[level] = tf.cast(box_outputs[level], tf.float32)
model_outputs = {
'cls_outputs': cls_outputs,
'box_outputs': box_outputs,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment