Commit c5f11e3f authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 346346141
parent 3b488d59
...@@ -178,7 +178,7 @@ class SpineNet(tf.keras.Model): ...@@ -178,7 +178,7 @@ class SpineNet(tf.keras.Model):
net = self._build_stem(inputs=inputs) net = self._build_stem(inputs=inputs)
net = self._build_scale_permuted_network( net = self._build_scale_permuted_network(
net=net, input_width=input_specs.shape[1]) net=net, input_width=input_specs.shape[2])
endpoints = self._build_endpoints(net=net) endpoints = self._build_endpoints(net=net)
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints} self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
......
...@@ -60,6 +60,38 @@ class SpineNetTest(parameterized.TestCase, tf.test.TestCase): ...@@ -60,6 +60,38 @@ class SpineNetTest(parameterized.TestCase, tf.test.TestCase):
[1, input_size / 2**l, input_size / 2**l, endpoints_num_filters], [1, input_size / 2**l, input_size / 2**l, endpoints_num_filters],
endpoints[str(l)].shape.as_list()) endpoints[str(l)].shape.as_list())
@parameterized.parameters(
((128, 128), (128, 128)),
((128, 128), (256, 256)),
((640, 640), (896, 1664)),
)
def test_load_from_different_input_specs(self, input_size_1, input_size_2):
"""Test loading checkpoints with different input size."""
def build_spinenet(input_size):
tf.keras.backend.set_image_data_format('channels_last')
input_specs = tf.keras.layers.InputSpec(
shape=[None, input_size[0], input_size[1], 3])
model = spinenet.SpineNet(
input_specs=input_specs,
min_level=3,
max_level=7,
endpoints_num_filters=384,
resample_alpha=1.0,
block_repeats=2,
filter_size_scale=0.5)
return model
model_1 = build_spinenet(input_size_1)
model_2 = build_spinenet(input_size_2)
ckpt_1 = tf.train.Checkpoint(backbone=model_1)
ckpt_2 = tf.train.Checkpoint(backbone=model_2)
ckpt_path = self.get_temp_dir() + '/ckpt'
ckpt_1.write(ckpt_path)
ckpt_2.restore(ckpt_path).expect_partial()
def test_serialize_deserialize(self): def test_serialize_deserialize(self):
# Create a network object that sets all of its config options. # Create a network object that sets all of its config options.
kwargs = dict( kwargs = dict(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment