Commit 09c0b474 authored by Jiageng Zhang's avatar Jiageng Zhang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 480388656
parent 09aeecd6
...@@ -195,11 +195,11 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -195,11 +195,11 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
training = tf.keras.backend.learning_phase() training = tf.keras.backend.learning_phase()
result = [] result = []
for i, layer in enumerate(self.aspp_layers): for i, layer in enumerate(self.aspp_layers):
result.append(tf.cast(layer(inputs, training=training), inputs.dtype)) x = layer(inputs, training=training)
# Apply resize layer to the end of the last set of layers.
if i == len(self.aspp_layers) - 1: if i == len(self.aspp_layers) - 1:
input_shape = inputs.get_shape().as_list() x = tf.image.resize(tf.cast(x, tf.float32), tf.shape(inputs)[1:3])
height, width = input_shape[1:3] result.append(tf.cast(x, inputs.dtype))
result[-1] = tf.image.resize(result[-1], [height, width])
result = tf.concat(result, axis=-1) result = tf.concat(result, axis=-1)
result = self.projection(result, training=training) result = self.projection(result, training=training)
return result return result
......
...@@ -1146,6 +1146,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -1146,6 +1146,8 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self._bn_axis = 1 self._bn_axis = 1
def build(self, input_shape): def build(self, input_shape):
height = input_shape[1]
width = input_shape[2]
channels = input_shape[3] channels = input_shape[3]
self.aspp_layers = [] self.aspp_layers = []
...@@ -1218,6 +1220,9 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -1218,6 +1220,9 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self.aspp_layers.append(pooling + [conv2, norm2]) self.aspp_layers.append(pooling + [conv2, norm2])
self._resizing_layer = tf.keras.layers.Resizing(
height, width, interpolation=self._interpolation, dtype=tf.float32)
self._projection = [ self._projection = [
tf.keras.layers.Conv2D( tf.keras.layers.Conv2D(
filters=self._output_channels, filters=self._output_channels,
...@@ -1249,9 +1254,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -1249,9 +1254,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
# Apply resize layer to the end of the last set of layers. # Apply resize layer to the end of the last set of layers.
if i == len(self.aspp_layers) - 1: if i == len(self.aspp_layers) - 1:
input_shape = inputs.get_shape().as_list() x = self._resizing_layer(x)
height, width = input_shape[1:3]
x = tf.image.resize(x, [height, width], self._interpolation)
result.append(tf.cast(x, inputs.dtype)) result.append(tf.cast(x, inputs.dtype))
x = self._concat_layer(result) x = self._concat_layer(result)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment