"git@developer.sourcefind.cn:OpenDAS/fastmoe.git" did not exist on "6cb6bbe429bf71f24e11e806716aa9b8f40d7558"
Commit 9abd85f2 authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Fixes typos and use string for endpoints key.

PiperOrigin-RevId: 363736548
parent dc04b2b0
...@@ -163,7 +163,7 @@ class ResNet3D(tf.keras.Model): ...@@ -163,7 +163,7 @@ class ResNet3D(tf.keras.Model):
block_repeats=resnet_spec[2], block_repeats=resnet_spec[2],
use_self_gating=use_self_gating[i] if use_self_gating else False, use_self_gating=use_self_gating[i] if use_self_gating else False,
name='block_group_l{}'.format(i + 2)) name='block_group_l{}'.format(i + 2))
endpoints[i + 2] = x endpoints[str(i + 2)] = x
self._output_specs = {l: endpoints[l].get_shape() for l in endpoints} self._output_specs = {l: endpoints[l].get_shape() for l in endpoints}
......
...@@ -47,16 +47,16 @@ class ResNet3DTest(parameterized.TestCase, tf.test.TestCase): ...@@ -47,16 +47,16 @@ class ResNet3DTest(parameterized.TestCase, tf.test.TestCase):
self.assertAllEqual([ self.assertAllEqual([
1, 2, input_size / 2**2, input_size / 2**2, 64 * endpoint_filter_scale 1, 2, input_size / 2**2, input_size / 2**2, 64 * endpoint_filter_scale
], endpoints[2].shape.as_list()) ], endpoints['2'].shape.as_list())
self.assertAllEqual([ self.assertAllEqual([
1, 2, input_size / 2**3, input_size / 2**3, 128 * endpoint_filter_scale 1, 2, input_size / 2**3, input_size / 2**3, 128 * endpoint_filter_scale
], endpoints[3].shape.as_list()) ], endpoints['3'].shape.as_list())
self.assertAllEqual([ self.assertAllEqual([
1, 2, input_size / 2**4, input_size / 2**4, 256 * endpoint_filter_scale 1, 2, input_size / 2**4, input_size / 2**4, 256 * endpoint_filter_scale
], endpoints[4].shape.as_list()) ], endpoints['4'].shape.as_list())
self.assertAllEqual([ self.assertAllEqual([
1, 2, input_size / 2**5, input_size / 2**5, 512 * endpoint_filter_scale 1, 2, input_size / 2**5, input_size / 2**5, 512 * endpoint_filter_scale
], endpoints[5].shape.as_list()) ], endpoints['5'].shape.as_list())
def test_serialize_deserialize(self): def test_serialize_deserialize(self):
# Create a network object that sets all of its config options. # Create a network object that sets all of its config options.
......
...@@ -210,7 +210,7 @@ class BottleneckBlock3D(tf.keras.layers.Layer): ...@@ -210,7 +210,7 @@ class BottleneckBlock3D(tf.keras.layers.Layer):
'temporal_kernel_size': self._temporal_kernel_size, 'temporal_kernel_size': self._temporal_kernel_size,
'temporal_strides': self._temporal_strides, 'temporal_strides': self._temporal_strides,
'spatial_strides': self._spatial_strides, 'spatial_strides': self._spatial_strides,
'use_projection': self._use_projection, 'use_self_gating': self._use_self_gating,
'kernel_initializer': self._kernel_initializer, 'kernel_initializer': self._kernel_initializer,
'kernel_regularizer': self._kernel_regularizer, 'kernel_regularizer': self._kernel_regularizer,
'bias_regularizer': self._bias_regularizer, 'bias_regularizer': self._bias_regularizer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment