Commit 363a36cd authored by Tomer Kaftan's avatar Tomer Kaftan Committed by A. Unique TensorFlower
Browse files

Make hack in official/vision/detection models that enters the backend keras...

Make hack in official/vision/detection models that enters the backend keras graph stop happening once we enable the Functional API KerasTensors refactoring:

As a workaround hack for the tf op layer conversion being fragile, the detection models have to explicitly enter the Keras backend graph.

When we enable the KerasTensors refactoring of the Functional API internals, the op layer conversion will be much more reliable and this hack will not be necessary. In addition, the hack actually causes the models to break when we enable the refactoring (because it causes tensors to leak out of a graph).

So, this CL changes the existing hack to stop applying once we've enabled the KerasTensors refactoring.

PiperOrigin-RevId: 322229802
parent cdda0906
......@@ -28,7 +28,7 @@ import functools
import tensorflow as tf
from tensorflow.python.keras import backend
from official.vision.detection.modeling.architecture import keras_utils
from official.vision.detection.modeling.architecture import nn_ops
from official.vision.detection.ops import spatial_transform_ops
......@@ -120,7 +120,7 @@ class Fpn(object):
'The minimum backbone level %d should be '%(min(input_levels)) +
'less or equal to FPN minimum level %d.:'%(self._min_level))
backbone_max_level = min(max(input_levels), self._max_level)
with backend.get_graph().as_default(), tf.name_scope('fpn'):
with keras_utils.maybe_enter_backend_graph(), tf.name_scope('fpn'):
# Adds lateral connections.
feats_lateral = {}
for level in range(self._min_level, backbone_max_level + 1):
......
......@@ -22,7 +22,8 @@ import functools
import numpy as np
import tensorflow as tf
from tensorflow.python.keras import backend
from official.vision.detection.modeling.architecture import keras_utils
from official.vision.detection.modeling.architecture import nn_ops
from official.vision.detection.ops import spatial_transform_ops
......@@ -127,7 +128,7 @@ class RpnHead(tf.keras.layers.Layer):
scores_outputs = {}
box_outputs = {}
with backend.get_graph().as_default(), tf.name_scope('rpn_head'):
with keras_utils.maybe_enter_backend_graph(), tf.name_scope('rpn_head'):
for level in range(self._min_level, self._max_level + 1):
scores_output, box_output = self._shared_rpn_heads(
features[level], self._anchors_per_location, level, is_training)
......@@ -249,7 +250,8 @@ class FastrcnnHead(tf.keras.layers.Layer):
predictions.
"""
with backend.get_graph().as_default(), tf.name_scope('fast_rcnn_head'):
with keras_utils.maybe_enter_backend_graph(), tf.name_scope(
'fast_rcnn_head'):
# reshape inputs beofre FC.
_, num_rois, height, width, filters = roi_features.get_shape().as_list()
......@@ -368,7 +370,7 @@ class MaskrcnnHead(tf.keras.layers.Layer):
boxes is not 4.
"""
with backend.get_graph().as_default():
with keras_utils.maybe_enter_backend_graph():
with tf.name_scope('mask_head'):
_, num_rois, height, width, filters = roi_features.get_shape().as_list()
net = tf.reshape(roi_features, [-1, height, width, filters])
......@@ -552,7 +554,8 @@ class RetinanetHead(object):
"""Returns outputs of RetinaNet head."""
class_outputs = {}
box_outputs = {}
with backend.get_graph().as_default(), tf.name_scope('retinanet_head'):
with keras_utils.maybe_enter_backend_graph(), tf.name_scope(
'retinanet_head'):
for level in range(self._min_level, self._max_level + 1):
features = fpn_features[level]
......@@ -644,7 +647,7 @@ class ShapemaskPriorHead(object):
detection_priors: A float Tensor of shape [batch_size * num_instances,
mask_size, mask_size, 1].
"""
with backend.get_graph().as_default(), tf.name_scope('prior_mask'):
with keras_utils.maybe_enter_backend_graph(), tf.name_scope('prior_mask'):
batch_size, num_instances, _ = boxes.get_shape().as_list()
outer_boxes = tf.cast(outer_boxes, tf.float32)
boxes = tf.cast(boxes, tf.float32)
......@@ -807,7 +810,7 @@ class ShapemaskCoarsemaskHead(object):
mask_outputs: instance mask prediction as a float Tensor of shape
[batch_size, num_instances, mask_size, mask_size].
"""
with backend.get_graph().as_default(), tf.name_scope('coarse_mask'):
with keras_utils.maybe_enter_backend_graph(), tf.name_scope('coarse_mask'):
# Transform detection priors to have the same dimension as features.
detection_priors = tf.expand_dims(detection_priors, axis=-1)
detection_priors = self._coarse_mask_fc(detection_priors)
......@@ -939,7 +942,7 @@ class ShapemaskFinemaskHead(object):
"""
# Extract the foreground mean features
# with tf.variable_scope('fine_mask', reuse=tf.AUTO_REUSE):
with backend.get_graph().as_default(), tf.name_scope('fine_mask'):
with keras_utils.maybe_enter_backend_graph(), tf.name_scope('fine_mask'):
mask_probs = tf.nn.sigmoid(mask_logits)
# Compute instance embedding for hard average.
binary_mask = tf.cast(tf.greater(mask_probs, 0.5), features.dtype)
......
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Util functions to integrate with Keras internals."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import backend
try:
from tensorflow.python.keras.engine import keras_tensor # pylint: disable=g-import-not-at-top,unused-import
except ImportError:
keras_tensor = None
class NoOpContextManager(object):
def __enter__(self):
pass
def __exit__(self, *args):
pass
def maybe_enter_backend_graph():
if (keras_tensor is not None) and keras_tensor.keras_tensors_enabled():
return NoOpContextManager()
else:
return backend.get_graph().as_default()
......@@ -25,7 +25,7 @@ from __future__ import print_function
from absl import logging
import tensorflow as tf
from tensorflow.python.keras import backend
from official.vision.detection.modeling.architecture import keras_utils
from official.vision.detection.modeling.architecture import nn_ops
# TODO(b/140112644): Refactor the code with Keras style, i.e. build and call.
......@@ -90,7 +90,7 @@ class Resnet(object):
The values are corresponding feature hierarchy in ResNet with shape
[batch_size, height_l, width_l, num_filters].
"""
with backend.get_graph().as_default():
with keras_utils.maybe_enter_backend_graph():
with tf.name_scope('resnet%s' % self._resnet_depth):
return self._resnet_fn(inputs, is_training)
......
......@@ -24,8 +24,8 @@ import math
from absl import logging
import tensorflow as tf
from tensorflow.python.keras import backend
from official.modeling import tf_utils
from official.vision.detection.modeling.architecture import keras_utils
from official.vision.detection.modeling.architecture import nn_blocks
layers = tf.keras.layers
......@@ -486,7 +486,7 @@ class SpineNetBuilder(object):
self._norm_epsilon = norm_epsilon
def __call__(self, inputs, is_training=None):
with backend.get_graph().as_default():
with keras_utils.maybe_enter_backend_graph():
model = SpineNet(
input_specs=self._input_specs,
min_level=self._min_level,
......
......@@ -20,13 +20,13 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras import backend
from official.vision.detection.dataloader import anchor
from official.vision.detection.dataloader import mode_keys
from official.vision.detection.evaluation import factory as eval_factory
from official.vision.detection.modeling import base_model
from official.vision.detection.modeling import losses
from official.vision.detection.modeling.architecture import factory
from official.vision.detection.modeling.architecture import keras_utils
from official.vision.detection.ops import postprocess_ops
from official.vision.detection.ops import roi_ops
from official.vision.detection.ops import spatial_transform_ops
......@@ -297,7 +297,7 @@ class MaskrcnnModel(base_model.Model):
def build_model(self, params, mode):
if self._keras_model is None:
input_layers = self.build_input_layers(self._params, mode)
with backend.get_graph().as_default():
with keras_utils.maybe_enter_backend_graph():
outputs = self.model_outputs(input_layers, mode)
model = tf.keras.models.Model(
......
......@@ -20,12 +20,12 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras import backend
from official.vision.detection.dataloader import mode_keys
from official.vision.detection.evaluation import factory as eval_factory
from official.vision.detection.modeling import base_model
from official.vision.detection.modeling import losses
from official.vision.detection.modeling.architecture import factory
from official.vision.detection.modeling.architecture import keras_utils
from official.vision.detection.ops import postprocess_ops
......@@ -120,7 +120,7 @@ class RetinanetModel(base_model.Model):
def build_model(self, params, mode=None):
if self._keras_model is None:
with backend.get_graph().as_default():
with keras_utils.maybe_enter_backend_graph():
outputs = self.model_outputs(self._input_layer, mode)
model = tf.keras.models.Model(
......
......@@ -20,13 +20,13 @@ from __future__ import print_function
import tensorflow as tf
from tensorflow.python.keras import backend
from official.vision.detection.dataloader import anchor
from official.vision.detection.dataloader import mode_keys
from official.vision.detection.evaluation import factory as eval_factory
from official.vision.detection.modeling import base_model
from official.vision.detection.modeling import losses
from official.vision.detection.modeling.architecture import factory
from official.vision.detection.modeling.architecture import keras_utils
from official.vision.detection.ops import postprocess_ops
from official.vision.detection.utils import box_utils
......@@ -265,7 +265,7 @@ class ShapeMaskModel(base_model.Model):
def build_model(self, params, mode):
if self._keras_model is None:
input_layers = self.build_input_layers(self._params, mode)
with backend.get_graph().as_default():
with keras_utils.maybe_enter_backend_graph():
outputs = self.model_outputs(input_layers, mode)
model = tf.keras.models.Model(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment