Commit 9abd85f2 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Fixes typos and use string for endpoints key.

PiperOrigin-RevId: 363736548
parent dc04b2b0
......@@ -163,7 +163,7 @@ class ResNet3D(tf.keras.Model):
block_repeats=resnet_spec[2],
use_self_gating=use_self_gating[i] if use_self_gating else False,
name='block_group_l{}'.format(i + 2))
endpoints[i + 2] = x
endpoints[str(i + 2)] = x
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
......
......@@ -47,16 +47,16 @@ class ResNet3DTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual([
1, 2, input_size / 2**2, input_size / 2**2, 64 * endpoint_filter_scale
], endpoints[2].shape.as_list())
], endpoints['2'].shape.as_list())
self.assertAllEqual([
1, 2, input_size / 2**3, input_size / 2**3, 128 * endpoint_filter_scale
], endpoints[3].shape.as_list())
], endpoints['3'].shape.as_list())
self.assertAllEqual([
1, 2, input_size / 2**4, input_size / 2**4, 256 * endpoint_filter_scale
], endpoints[4].shape.as_list())
], endpoints['4'].shape.as_list())
self.assertAllEqual([
1, 2, input_size / 2**5, input_size / 2**5, 512 * endpoint_filter_scale
], endpoints[5].shape.as_list())
], endpoints['5'].shape.as_list())
def test_serialize_deserialize(self):
# Create a network object that sets all of its config options.
......
......@@ -210,7 +210,7 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
'temporal_kernel_size': self._temporal_kernel_size,
'temporal_strides': self._temporal_strides,
'spatial_strides': self._spatial_strides,
'use_projection': self._use_projection,
'use_self_gating': self._use_self_gating,
'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment