Commit 78d99a22 authored by Fan Yang's avatar Fan Yang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 413811970
parent b7f0d1ae
...@@ -155,6 +155,7 @@ class SegmentationHead3D(tf.keras.layers.Layer): ...@@ -155,6 +155,7 @@ class SegmentationHead3D(tf.keras.layers.Layer):
- key: A `str` of the level of the multilevel features. - key: A `str` of the level of the multilevel features.
- values: A `tf.Tensor` of the feature map tensors, whose shape is - values: A `tf.Tensor` of the feature map tensors, whose shape is
[batch, height_l, width_l, channels]. [batch, height_l, width_l, channels].
The first is backbone endpoints, and the second is decoder endpoints.
Returns: Returns:
segmentation prediction mask: A `tf.Tensor` of the segmentation mask segmentation prediction mask: A `tf.Tensor` of the segmentation mask
scores predicted from input features. scores predicted from input features.
......
...@@ -149,7 +149,8 @@ class ASPP(tf.keras.layers.Layer): ...@@ -149,7 +149,8 @@ class ASPP(tf.keras.layers.Layer):
return outputs if self._config_dict['output_tensor'] else {level: outputs} return outputs if self._config_dict['output_tensor'] else {level: outputs}
def get_config(self) -> Mapping[str, Any]: def get_config(self) -> Mapping[str, Any]:
return self._config_dict base_config = super().get_config()
return dict(list(base_config.items()) + list(self._config_dict.items()))
@classmethod @classmethod
def from_config(cls, config, custom_objects=None): def from_config(cls, config, custom_objects=None):
......
...@@ -74,7 +74,10 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase): ...@@ -74,7 +74,10 @@ class ASPPTest(parameterized.TestCase, tf.test.TestCase):
dropout_rate=0.2, dropout_rate=0.2,
use_depthwise_convolution='false', use_depthwise_convolution='false',
spp_layer_version='v1', spp_layer_version='v1',
output_tensor=False) output_tensor=False,
dtype='float32',
name='aspp',
trainable=True)
network = aspp.ASPP(**kwargs) network = aspp.ASPP(**kwargs)
expected_config = dict(kwargs) expected_config = dict(kwargs)
......
...@@ -133,6 +133,10 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase): ...@@ -133,6 +133,10 @@ class FactoryTest(tf.test.TestCase, parameterized.TestCase):
network_config = network.get_config() network_config = network.get_config()
factory_network_config = factory_network.get_config() factory_network_config = factory_network.get_config()
# Due to calling `super().get_config()` in aspp layer, everything but the
# the name of two layer instances are the same, so we force equal name so it
# will not give false alarm.
factory_network_config['name'] = network_config['name']
self.assertEqual(network_config, factory_network_config) self.assertEqual(network_config, factory_network_config)
......
...@@ -202,7 +202,7 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -202,7 +202,7 @@ class SegmentationHead(tf.keras.layers.Layer):
kernel_regularizer=self._config_dict['kernel_regularizer'], kernel_regularizer=self._config_dict['kernel_regularizer'],
bias_regularizer=self._config_dict['bias_regularizer']) bias_regularizer=self._config_dict['bias_regularizer'])
super(SegmentationHead, self).build(input_shape) super().build(input_shape)
def call(self, inputs: Tuple[Union[tf.Tensor, Mapping[str, tf.Tensor]], def call(self, inputs: Tuple[Union[tf.Tensor, Mapping[str, tf.Tensor]],
Union[tf.Tensor, Mapping[str, tf.Tensor]]]): Union[tf.Tensor, Mapping[str, tf.Tensor]]]):
...@@ -220,6 +220,7 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -220,6 +220,7 @@ class SegmentationHead(tf.keras.layers.Layer):
- key: A `str` of the level of the multilevel features. - key: A `str` of the level of the multilevel features.
- values: A `tf.Tensor` of the feature map tensors, whose shape is - values: A `tf.Tensor` of the feature map tensors, whose shape is
[batch, height_l, width_l, channels]. [batch, height_l, width_l, channels].
The first is backbone endpoints, and the second is decoder endpoints.
Returns: Returns:
segmentation prediction mask: A `tf.Tensor` of the segmentation mask segmentation prediction mask: A `tf.Tensor` of the segmentation mask
scores predicted from input features. scores predicted from input features.
...@@ -261,7 +262,8 @@ class SegmentationHead(tf.keras.layers.Layer): ...@@ -261,7 +262,8 @@ class SegmentationHead(tf.keras.layers.Layer):
return self._classifier(x) return self._classifier(x)
def get_config(self): def get_config(self):
return self._config_dict base_config = super().get_config()
return dict(list(base_config.items()) + list(self._config_dict.items()))
@classmethod @classmethod
def from_config(cls, config): def from_config(cls, config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment