Unverified commit a555f1b0 authored by vivek rathod, committed by GitHub
Browse files

Merged commit includes the following changes: (#8740)

318497061  by rathodv:

    1. Replace strategy.run() with strategy.experimental_run_v2() and replace tensor.ref() with tensor.experimental_ref() to be compatible with TF2.1 runtime on cloud.
    2. Fix expected string in failing PY3 tests.

--
318493408  by aom:

    Implements "Bidirectional Feature Pyramid Network Generators" for BiFPN-based feature extractors (e.g. EfficientDet).

--

PiperOrigin-RevId: 318497061
parent 0f0c7745
...@@ -390,7 +390,7 @@ class DatasetBuilderTest(test_case.TestCase): ...@@ -390,7 +390,7 @@ class DatasetBuilderTest(test_case.TestCase):
return iter1.get_next(), iter2.get_next() return iter1.get_next(), iter2.get_next()
output_dict1, output_dict2 = self.execute(graph_fn, []) output_dict1, output_dict2 = self.execute(graph_fn, [])
self.assertAllEqual(['0'], output_dict1[fields.InputDataFields.source_id]) self.assertAllEqual([b'0'], output_dict1[fields.InputDataFields.source_id])
self.assertEqual([b'1'], output_dict2[fields.InputDataFields.source_id]) self.assertEqual([b'1'], output_dict2[fields.InputDataFields.source_id])
def test_sample_one_of_n_shards(self): def test_sample_one_of_n_shards(self):
......
...@@ -200,7 +200,7 @@ class GenerateContextDataTest(tf.test.TestCase): ...@@ -200,7 +200,7 @@ class GenerateContextDataTest(tf.test.TestCase):
seq_feature_dict['region/label/string'].feature[1].bytes_list.value[:]) seq_feature_dict['region/label/string'].feature[1].bytes_list.value[:])
def assert_expected_key(self, key): def assert_expected_key(self, key):
self.assertAllEqual(key, '01') self.assertAllEqual(key, b'01')
def assert_sorted(self, example_collection): def assert_sorted(self, example_collection):
example_list = list(example_collection) example_list = list(example_collection)
......
...@@ -95,13 +95,13 @@ class CreateCOCOCameraTrapsTfexampleTest(tf.test.TestCase): ...@@ -95,13 +95,13 @@ class CreateCOCOCameraTrapsTfexampleTest(tf.test.TestCase):
.int64_list.value, [1]) .int64_list.value, [1])
self.assertAllEqual( self.assertAllEqual(
example.features.feature['image/object/class/text'] example.features.feature['image/object/class/text']
.bytes_list.value, ['animal']) .bytes_list.value, [b'animal'])
self.assertAllClose( self.assertAllClose(
example.features.feature['image/class/label'] example.features.feature['image/class/label']
.int64_list.value, [1]) .int64_list.value, [1])
self.assertAllEqual( self.assertAllEqual(
example.features.feature['image/class/text'] example.features.feature['image/class/text']
.bytes_list.value, ['animal']) .bytes_list.value, [b'animal'])
# Check other essential attributes. # Check other essential attributes.
self.assertAllEqual( self.assertAllEqual(
...@@ -112,7 +112,7 @@ class CreateCOCOCameraTrapsTfexampleTest(tf.test.TestCase): ...@@ -112,7 +112,7 @@ class CreateCOCOCameraTrapsTfexampleTest(tf.test.TestCase):
[self.IMAGE_WIDTH]) [self.IMAGE_WIDTH])
self.assertAllEqual( self.assertAllEqual(
example.features.feature['image/source_id'].bytes_list.value, example.features.feature['image/source_id'].bytes_list.value,
['im_0']) [b'im_0'])
self.assertTrue( self.assertTrue(
example.features.feature['image/encoded'].bytes_list.value) example.features.feature['image/encoded'].bytes_list.value)
...@@ -134,13 +134,13 @@ class CreateCOCOCameraTrapsTfexampleTest(tf.test.TestCase): ...@@ -134,13 +134,13 @@ class CreateCOCOCameraTrapsTfexampleTest(tf.test.TestCase):
.int64_list.value, [1]) .int64_list.value, [1])
self.assertAllEqual( self.assertAllEqual(
example.features.feature['image/object/class/text'] example.features.feature['image/object/class/text']
.bytes_list.value, ['animal']) .bytes_list.value, [b'animal'])
self.assertAllClose( self.assertAllClose(
example.features.feature['image/class/label'] example.features.feature['image/class/label']
.int64_list.value, [1]) .int64_list.value, [1])
self.assertAllEqual( self.assertAllEqual(
example.features.feature['image/class/text'] example.features.feature['image/class/text']
.bytes_list.value, ['animal']) .bytes_list.value, [b'animal'])
# Check other essential attributes. # Check other essential attributes.
self.assertAllEqual( self.assertAllEqual(
...@@ -151,7 +151,7 @@ class CreateCOCOCameraTrapsTfexampleTest(tf.test.TestCase): ...@@ -151,7 +151,7 @@ class CreateCOCOCameraTrapsTfexampleTest(tf.test.TestCase):
[self.IMAGE_WIDTH]) [self.IMAGE_WIDTH])
self.assertAllEqual( self.assertAllEqual(
example.features.feature['image/source_id'].bytes_list.value, example.features.feature['image/source_id'].bytes_list.value,
['im_0']) [b'im_0'])
self.assertTrue( self.assertTrue(
example.features.feature['image/encoded'].bytes_list.value) example.features.feature['image/encoded'].bytes_list.value)
......
...@@ -239,13 +239,13 @@ class GenerateEmbeddingData(tf.test.TestCase): ...@@ -239,13 +239,13 @@ class GenerateEmbeddingData(tf.test.TestCase):
.int64_list.value, [5]) .int64_list.value, [5])
self.assertAllEqual( self.assertAllEqual(
example.features.feature['image/object/class/text'] example.features.feature['image/object/class/text']
.bytes_list.value, ['hyena']) .bytes_list.value, [b'hyena'])
self.assertAllClose( self.assertAllClose(
example.features.feature['image/class/label'] example.features.feature['image/class/label']
.int64_list.value, [5]) .int64_list.value, [5])
self.assertAllEqual( self.assertAllEqual(
example.features.feature['image/class/text'] example.features.feature['image/class/text']
.bytes_list.value, ['hyena']) .bytes_list.value, [b'hyena'])
# Check other essential attributes. # Check other essential attributes.
self.assertAllEqual( self.assertAllEqual(
...@@ -254,7 +254,7 @@ class GenerateEmbeddingData(tf.test.TestCase): ...@@ -254,7 +254,7 @@ class GenerateEmbeddingData(tf.test.TestCase):
example.features.feature['image/width'].int64_list.value, [600]) example.features.feature['image/width'].int64_list.value, [600])
self.assertAllEqual( self.assertAllEqual(
example.features.feature['image/source_id'].bytes_list.value, example.features.feature['image/source_id'].bytes_list.value,
['image_id']) [b'image_id'])
self.assertTrue( self.assertTrue(
example.features.feature['image/encoded'].bytes_list.value) example.features.feature['image/encoded'].bytes_list.value)
...@@ -271,7 +271,7 @@ class GenerateEmbeddingData(tf.test.TestCase): ...@@ -271,7 +271,7 @@ class GenerateEmbeddingData(tf.test.TestCase):
.int64_list.value, [5]) .int64_list.value, [5])
self.assertAllEqual(tf.train.Example.FromString( self.assertAllEqual(tf.train.Example.FromString(
generated_example).features.feature['image/object/class/text'] generated_example).features.feature['image/object/class/text']
.bytes_list.value, ['hyena']) .bytes_list.value, [b'hyena'])
output = inference_fn.process(generated_example) output = inference_fn.process(generated_example)
output_example = output[0] output_example = output[0]
self.assert_expected_example(output_example) self.assert_expected_example(output_example)
...@@ -307,7 +307,7 @@ class GenerateEmbeddingData(tf.test.TestCase): ...@@ -307,7 +307,7 @@ class GenerateEmbeddingData(tf.test.TestCase):
.feature['image/object/class/label'].int64_list.value, [5]) .feature['image/object/class/label'].int64_list.value, [5])
self.assertAllEqual( self.assertAllEqual(
tf.train.Example.FromString(generated_example).features tf.train.Example.FromString(generated_example).features
.feature['image/object/class/text'].bytes_list.value, ['hyena']) .feature['image/object/class/text'].bytes_list.value, [b'hyena'])
output = inference_fn.process(generated_example) output = inference_fn.process(generated_example)
output_example = output[0] output_example = output[0]
self.assert_expected_example(output_example, botk=True) self.assert_expected_example(output_example, botk=True)
......
...@@ -288,7 +288,7 @@ class SeqExampleUtilTest(tf.test.TestCase): ...@@ -288,7 +288,7 @@ class SeqExampleUtilTest(tf.test.TestCase):
[0.75, 1.], [0.75, 1.],
seq_feature_dict['region/bbox/xmax'].feature[0].float_list.value[:]) seq_feature_dict['region/bbox/xmax'].feature[0].float_list.value[:])
self.assertAllEqual( self.assertAllEqual(
['cat', 'frog'], [b'cat', b'frog'],
seq_feature_dict['region/label/string'].feature[0].bytes_list.value[:]) seq_feature_dict['region/label/string'].feature[0].bytes_list.value[:])
self.assertAllClose( self.assertAllClose(
[0.], [0.],
...@@ -332,7 +332,7 @@ class SeqExampleUtilTest(tf.test.TestCase): ...@@ -332,7 +332,7 @@ class SeqExampleUtilTest(tf.test.TestCase):
[0.75], [0.75],
seq_feature_dict['region/bbox/xmax'].feature[1].float_list.value[:]) seq_feature_dict['region/bbox/xmax'].feature[1].float_list.value[:])
self.assertAllEqual( self.assertAllEqual(
['cat'], [b'cat'],
seq_feature_dict['region/label/string'].feature[1].bytes_list.value[:]) seq_feature_dict['region/label/string'].feature[1].bytes_list.value[:])
self.assertAllClose( self.assertAllClose(
[], [],
......
...@@ -42,7 +42,7 @@ class OpenOutputTfrecordsTests(tf.test.TestCase): ...@@ -42,7 +42,7 @@ class OpenOutputTfrecordsTests(tf.test.TestCase):
tf_record_path = '{}-{:05d}-of-00010'.format( tf_record_path = '{}-{:05d}-of-00010'.format(
os.path.join(tf.test.get_temp_dir(), 'test.tfrec'), idx) os.path.join(tf.test.get_temp_dir(), 'test.tfrec'), idx)
records = list(tf.python_io.tf_record_iterator(tf_record_path)) records = list(tf.python_io.tf_record_iterator(tf_record_path))
self.assertAllEqual(records, ['test_{}'.format(idx)]) self.assertAllEqual(records, ['test_{}'.format(idx).encode('utf-8')])
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -419,7 +419,7 @@ class ExportTfliteGraphTest(tf.test.TestCase): ...@@ -419,7 +419,7 @@ class ExportTfliteGraphTest(tf.test.TestCase):
tflite_graph_file = self._export_graph_with_postprocessing_op( tflite_graph_file = self._export_graph_with_postprocessing_op(
pipeline_config) pipeline_config)
self.assertTrue(os.path.exists(tflite_graph_file)) self.assertTrue(os.path.exists(tflite_graph_file))
mock_get.assert_called_once() self.assertEqual(1, mock_get.call_count)
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -336,7 +336,7 @@ def load_fine_tune_checkpoint( ...@@ -336,7 +336,7 @@ def load_fine_tune_checkpoint(
labels) labels)
strategy = tf.compat.v2.distribute.get_strategy() strategy = tf.compat.v2.distribute.get_strategy()
strategy.run( strategy.experimental_run_v2(
_dummy_computation_fn, args=( _dummy_computation_fn, args=(
features, features,
labels, labels,
...@@ -562,7 +562,7 @@ def train_loop( ...@@ -562,7 +562,7 @@ def train_loop(
def _sample_and_train(strategy, train_step_fn, data_iterator): def _sample_and_train(strategy, train_step_fn, data_iterator):
features, labels = data_iterator.next() features, labels = data_iterator.next()
per_replica_losses = strategy.run( per_replica_losses = strategy.experimental_run_v2(
train_step_fn, args=(features, labels)) train_step_fn, args=(features, labels))
# TODO(anjalisridhar): explore if it is safe to remove the # TODO(anjalisridhar): explore if it is safe to remove the
## num_replicas scaling of the loss and switch this to a ReduceOp.Mean ## num_replicas scaling of the loss and switch this to a ReduceOp.Mean
......
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for bidirectional feature pyramid generators."""
import unittest
from absl.testing import parameterized
import tensorflow.compat.v1 as tf
from google.protobuf import text_format
from object_detection.builders import hyperparams_builder
from object_detection.models import bidirectional_feature_pyramid_generators as bifpn_generators
from object_detection.protos import hyperparams_pb2
from object_detection.utils import test_case
from object_detection.utils import test_utils
from object_detection.utils import tf_version
@parameterized.parameters({'bifpn_num_iterations': 2},
                          {'bifpn_num_iterations': 8})
@unittest.skipIf(tf_version.is_tf1(), 'Skipping TF2.X only test.')
class BiFPNFeaturePyramidGeneratorTest(test_case.TestCase):
  """Tests for KerasBiFpnFeatureMaps (BiFPN feature pyramid generator).

  Every test method is parameterized over ``bifpn_num_iterations`` (2 and 8)
  to cover both a shallow and a deep stack of BiFPN iterations.
  """

  def _build_conv_hyperparams(self):
    """Builds a minimal KerasLayerHyperparams config for the generator.

    Returns:
      A hyperparams_builder.KerasLayerHyperparams built from a text proto
      with an l2 regularizer, a truncated-normal initializer, and
      force_use_bias enabled.
    """
    conv_hyperparams = hyperparams_pb2.Hyperparams()
    conv_hyperparams_text_proto = """
      regularizer {
        l2_regularizer {
        }
      }
      initializer {
        truncated_normal_initializer {
        }
      }
      force_use_bias: true
    """
    text_format.Merge(conv_hyperparams_text_proto, conv_hyperparams)
    return hyperparams_builder.KerasLayerHyperparams(conv_hyperparams)

  def test_get_expected_feature_map_shapes(self, bifpn_num_iterations):
    """Checks the names and shapes of the generated feature maps."""
    with test_utils.GraphContextOrNone() as g:
      # Three backbone features (levels 3-5) with successively halved
      # spatial size; levels 6 and 7 must be synthesized by the generator
      # since input_max_level=5 while fpn_max_level=7.
      image_features = [
          ('block3', tf.random_uniform([4, 16, 16, 256], dtype=tf.float32)),
          ('block4', tf.random_uniform([4, 8, 8, 256], dtype=tf.float32)),
          ('block5', tf.random_uniform([4, 4, 4, 256], dtype=tf.float32))
      ]
      bifpn_generator = bifpn_generators.KerasBiFpnFeatureMaps(
          bifpn_num_iterations=bifpn_num_iterations,
          bifpn_num_filters=128,
          fpn_min_level=3,
          fpn_max_level=7,
          input_max_level=5,
          is_training=True,
          conv_hyperparams=self._build_conv_hyperparams(),
          freeze_batchnorm=False)
    def graph_fn():
      feature_maps = bifpn_generator(image_features)
      return feature_maps
    # Output keys are prefixed with the final iteration index; every map
    # carries bifpn_num_filters=128 channels, spatial size halving per level.
    expected_feature_map_shapes = {
        '{}_dn_lvl_3'.format(bifpn_num_iterations): (4, 16, 16, 128),
        '{}_up_lvl_4'.format(bifpn_num_iterations): (4, 8, 8, 128),
        '{}_up_lvl_5'.format(bifpn_num_iterations): (4, 4, 4, 128),
        '{}_up_lvl_6'.format(bifpn_num_iterations): (4, 2, 2, 128),
        '{}_up_lvl_7'.format(bifpn_num_iterations): (4, 1, 1, 128)}
    out_feature_maps = self.execute(graph_fn, [], g)
    out_feature_map_shapes = dict(
        (key, value.shape) for key, value in out_feature_maps.items())
    self.assertDictEqual(expected_feature_map_shapes, out_feature_map_shapes)

  def test_get_expected_variable_names(self, bifpn_num_iterations):
    """Checks that the generator creates exactly the expected variables."""
    with test_utils.GraphContextOrNone() as g:
      image_features = [
          ('block3', tf.random_uniform([4, 16, 16, 256], dtype=tf.float32)),
          ('block4', tf.random_uniform([4, 8, 8, 256], dtype=tf.float32)),
          ('block5', tf.random_uniform([4, 4, 4, 256], dtype=tf.float32))
      ]
      bifpn_generator = bifpn_generators.KerasBiFpnFeatureMaps(
          bifpn_num_iterations=bifpn_num_iterations,
          bifpn_num_filters=128,
          fpn_min_level=3,
          fpn_max_level=7,
          input_max_level=5,
          is_training=True,
          conv_hyperparams=self._build_conv_hyperparams(),
          freeze_batchnorm=False,
          name='bifpn')
    def graph_fn():
      return bifpn_generator(image_features)
    # Build the layer (variables are created on first call).
    self.execute(graph_fn, [], g)
    # Variables with fixed names: the 1x1 pre-sample convolutions, present
    # regardless of bifpn_num_iterations.
    expected_variables = [
        'bifpn/node_00/0_up_lvl_6/input_0_up_lvl_5/1x1_pre_sample/conv/bias',
        'bifpn/node_00/0_up_lvl_6/input_0_up_lvl_5/1x1_pre_sample/conv/kernel',
        'bifpn/node_03/1_dn_lvl_5/input_0_up_lvl_5/1x1_pre_sample/conv/bias',
        'bifpn/node_03/1_dn_lvl_5/input_0_up_lvl_5/1x1_pre_sample/conv/kernel',
        'bifpn/node_04/1_dn_lvl_4/input_0_up_lvl_4/1x1_pre_sample/conv/bias',
        'bifpn/node_04/1_dn_lvl_4/input_0_up_lvl_4/1x1_pre_sample/conv/kernel',
        'bifpn/node_05/1_dn_lvl_3/input_0_up_lvl_3/1x1_pre_sample/conv/bias',
        'bifpn/node_05/1_dn_lvl_3/input_0_up_lvl_3/1x1_pre_sample/conv/kernel',
        'bifpn/node_06/1_up_lvl_4/input_0_up_lvl_4/1x1_pre_sample/conv/bias',
        'bifpn/node_06/1_up_lvl_4/input_0_up_lvl_4/1x1_pre_sample/conv/kernel',
        'bifpn/node_07/1_up_lvl_5/input_0_up_lvl_5/1x1_pre_sample/conv/bias',
        'bifpn/node_07/1_up_lvl_5/input_0_up_lvl_5/1x1_pre_sample/conv/kernel']
    # Name templates for the per-iteration nodes: each BiFPN iteration has
    # eight combine nodes (top-down levels 6..3, then bottom-up 4..7), each
    # with combine weights and a separable-conv post-combine stage. The
    # templates are filled with (node index, iteration index) below.
    expected_node_variable_patterns = [
        ['bifpn/node_{:02}/{}_dn_lvl_6/combine/bifpn_combine_weights',
         'bifpn/node_{:02}/{}_dn_lvl_6/post_combine/separable_conv/bias',
         'bifpn/node_{:02}/{}_dn_lvl_6/post_combine/separable_conv/depthwise_kernel',
         'bifpn/node_{:02}/{}_dn_lvl_6/post_combine/separable_conv/pointwise_kernel'],
        ['bifpn/node_{:02}/{}_dn_lvl_5/combine/bifpn_combine_weights',
         'bifpn/node_{:02}/{}_dn_lvl_5/post_combine/separable_conv/bias',
         'bifpn/node_{:02}/{}_dn_lvl_5/post_combine/separable_conv/depthwise_kernel',
         'bifpn/node_{:02}/{}_dn_lvl_5/post_combine/separable_conv/pointwise_kernel'],
        ['bifpn/node_{:02}/{}_dn_lvl_4/combine/bifpn_combine_weights',
         'bifpn/node_{:02}/{}_dn_lvl_4/post_combine/separable_conv/bias',
         'bifpn/node_{:02}/{}_dn_lvl_4/post_combine/separable_conv/depthwise_kernel',
         'bifpn/node_{:02}/{}_dn_lvl_4/post_combine/separable_conv/pointwise_kernel'],
        ['bifpn/node_{:02}/{}_dn_lvl_3/combine/bifpn_combine_weights',
         'bifpn/node_{:02}/{}_dn_lvl_3/post_combine/separable_conv/bias',
         'bifpn/node_{:02}/{}_dn_lvl_3/post_combine/separable_conv/depthwise_kernel',
         'bifpn/node_{:02}/{}_dn_lvl_3/post_combine/separable_conv/pointwise_kernel'],
        ['bifpn/node_{:02}/{}_up_lvl_4/combine/bifpn_combine_weights',
         'bifpn/node_{:02}/{}_up_lvl_4/post_combine/separable_conv/bias',
         'bifpn/node_{:02}/{}_up_lvl_4/post_combine/separable_conv/depthwise_kernel',
         'bifpn/node_{:02}/{}_up_lvl_4/post_combine/separable_conv/pointwise_kernel'],
        ['bifpn/node_{:02}/{}_up_lvl_5/combine/bifpn_combine_weights',
         'bifpn/node_{:02}/{}_up_lvl_5/post_combine/separable_conv/bias',
         'bifpn/node_{:02}/{}_up_lvl_5/post_combine/separable_conv/depthwise_kernel',
         'bifpn/node_{:02}/{}_up_lvl_5/post_combine/separable_conv/pointwise_kernel'],
        ['bifpn/node_{:02}/{}_up_lvl_6/combine/bifpn_combine_weights',
         'bifpn/node_{:02}/{}_up_lvl_6/post_combine/separable_conv/bias',
         'bifpn/node_{:02}/{}_up_lvl_6/post_combine/separable_conv/depthwise_kernel',
         'bifpn/node_{:02}/{}_up_lvl_6/post_combine/separable_conv/pointwise_kernel'],
        ['bifpn/node_{:02}/{}_up_lvl_7/combine/bifpn_combine_weights',
         'bifpn/node_{:02}/{}_up_lvl_7/post_combine/separable_conv/bias',
         'bifpn/node_{:02}/{}_up_lvl_7/post_combine/separable_conv/depthwise_kernel',
         'bifpn/node_{:02}/{}_up_lvl_7/post_combine/separable_conv/pointwise_kernel']]
    # Node numbering for iteration nodes starts at 2; presumably nodes 0-1
    # are the input-processing nodes named in expected_variables above --
    # NOTE(review): confirm against the generator's node layout.
    node_i = 2
    for iter_i in range(1, bifpn_num_iterations+1):
      for node_variable_patterns in expected_node_variable_patterns:
        for pattern in node_variable_patterns:
          expected_variables.append(pattern.format(node_i, iter_i))
        node_i += 1
    expected_variables = set(expected_variables)
    # Strip the ':0' device suffix from each variable name before comparing.
    actual_variable_set = set(
        [var.name.split(':')[0] for var in bifpn_generator.variables])
    self.assertSetEqual(expected_variables, actual_variable_set)

# TODO(aom): Tests for create_bifpn_combine_op.
# TODO(aom): Tests for create_bifpn_combine_op.
# Script entry point: run all tests in this module.
if __name__ == '__main__':
  tf.test.main()
...@@ -43,6 +43,15 @@ def _get_padding_for_kernel_size(kernel_size): ...@@ -43,6 +43,15 @@ def _get_padding_for_kernel_size(kernel_size):
kernel_size)) kernel_size))
def batchnorm():
try:
return tf.keras.layers.experimental.SyncBatchNormalization(
name='batchnorm', epsilon=1e-5, momentum=0.1)
except AttributeError:
return tf.keras.layers.BatchNormalization(
name='batchnorm', epsilon=1e-5, momentum=0.1, fused=BATCH_NORM_FUSED)
class ConvolutionalBlock(tf.keras.layers.Layer): class ConvolutionalBlock(tf.keras.layers.Layer):
"""Block that aggregates Convolution + Norm layer + ReLU.""" """Block that aggregates Convolution + Norm layer + ReLU."""
...@@ -73,8 +82,7 @@ class ConvolutionalBlock(tf.keras.layers.Layer): ...@@ -73,8 +82,7 @@ class ConvolutionalBlock(tf.keras.layers.Layer):
filters=out_channels, kernel_size=kernel_size, use_bias=False, filters=out_channels, kernel_size=kernel_size, use_bias=False,
strides=stride, padding=padding) strides=stride, padding=padding)
self.norm = tf.keras.layers.experimental.SyncBatchNormalization( self.norm = batchnorm()
name='batchnorm', epsilon=1e-5, momentum=0.1)
if relu: if relu:
self.relu = tf.keras.layers.ReLU() self.relu = tf.keras.layers.ReLU()
...@@ -124,8 +132,7 @@ class ResidualBlock(tf.keras.layers.Layer): ...@@ -124,8 +132,7 @@ class ResidualBlock(tf.keras.layers.Layer):
self.conv = tf.keras.layers.Conv2D( self.conv = tf.keras.layers.Conv2D(
filters=out_channels, kernel_size=kernel_size, use_bias=False, filters=out_channels, kernel_size=kernel_size, use_bias=False,
strides=1, padding=padding) strides=1, padding=padding)
self.norm = tf.keras.layers.experimental.SyncBatchNormalization( self.norm = batchnorm()
name='batchnorm', epsilon=1e-5, momentum=0.1)
if skip_conv: if skip_conv:
self.skip = SkipConvolution(out_channels=out_channels, self.skip = SkipConvolution(out_channels=out_channels,
......
...@@ -54,8 +54,8 @@ def extract_submodel(model, inputs, outputs, name=None): ...@@ -54,8 +54,8 @@ def extract_submodel(model, inputs, outputs, name=None):
for layer in model.layers: for layer in model.layers:
layer_output = layer.output layer_output = layer.output
layer_inputs = layer.input layer_inputs = layer.input
output_to_layer[layer_output.ref()] = layer output_to_layer[layer_output.experimental_ref()] = layer
output_to_layer_input[layer_output.ref()] = layer_inputs output_to_layer_input[layer_output.experimental_ref()] = layer_inputs
model_inputs_dict = {} model_inputs_dict = {}
memoized_results = {} memoized_results = {}
...@@ -63,21 +63,22 @@ def extract_submodel(model, inputs, outputs, name=None): ...@@ -63,21 +63,22 @@ def extract_submodel(model, inputs, outputs, name=None):
# Relies on recursion, very low limit in python # Relies on recursion, very low limit in python
def _recurse_in_model(tensor): def _recurse_in_model(tensor):
"""Walk the existing model recursively to copy a submodel.""" """Walk the existing model recursively to copy a submodel."""
if tensor.ref() in memoized_results: if tensor.experimental_ref() in memoized_results:
return memoized_results[tensor.ref()] return memoized_results[tensor.experimental_ref()]
if (tensor.ref() == inputs.ref()) or ( if (tensor.experimental_ref() == inputs.experimental_ref()) or (
isinstance(inputs, list) and tensor in inputs): isinstance(inputs, list) and tensor in inputs):
if tensor.ref() not in model_inputs_dict: if tensor.experimental_ref() not in model_inputs_dict:
model_inputs_dict[tensor.ref()] = tf.keras.layers.Input(tensor=tensor) model_inputs_dict[tensor.experimental_ref()] = tf.keras.layers.Input(
out = model_inputs_dict[tensor.ref()] tensor=tensor)
out = model_inputs_dict[tensor.experimental_ref()]
else: else:
cur_inputs = output_to_layer_input[tensor.ref()] cur_inputs = output_to_layer_input[tensor.experimental_ref()]
cur_layer = output_to_layer[tensor.ref()] cur_layer = output_to_layer[tensor.experimental_ref()]
if isinstance(cur_inputs, list): if isinstance(cur_inputs, list):
out = cur_layer([_recurse_in_model(inp) for inp in cur_inputs]) out = cur_layer([_recurse_in_model(inp) for inp in cur_inputs])
else: else:
out = cur_layer(_recurse_in_model(cur_inputs)) out = cur_layer(_recurse_in_model(cur_inputs))
memoized_results[tensor.ref()] = out memoized_results[tensor.experimental_ref()] = out
return out return out
if isinstance(outputs, list): if isinstance(outputs, list):
...@@ -86,8 +87,10 @@ def extract_submodel(model, inputs, outputs, name=None): ...@@ -86,8 +87,10 @@ def extract_submodel(model, inputs, outputs, name=None):
model_outputs = _recurse_in_model(outputs) model_outputs = _recurse_in_model(outputs)
if isinstance(inputs, list): if isinstance(inputs, list):
model_inputs = [model_inputs_dict[tensor.ref()] for tensor in inputs] model_inputs = [
model_inputs_dict[tensor.experimental_ref()] for tensor in inputs
]
else: else:
model_inputs = model_inputs_dict[inputs.ref()] model_inputs = model_inputs_dict[inputs.experimental_ref()]
return tf.keras.Model(inputs=model_inputs, outputs=model_outputs, name=name) return tf.keras.Model(inputs=model_inputs, outputs=model_outputs, name=name)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment