"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "5fb0ff9ae3c024d9c60906f0e1707610df491e72"
Commit bb35d42e authored by Abdullah Rashwan's avatar Abdullah Rashwan Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 352689864
parent 811ca229
...@@ -142,7 +142,10 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -142,7 +142,10 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
epsilon=self.batchnorm_epsilon), epsilon=self.batchnorm_epsilon),
tf.keras.layers.Activation(self.activation), tf.keras.layers.Activation(self.activation),
tf.keras.layers.experimental.preprocessing.Resizing( tf.keras.layers.experimental.preprocessing.Resizing(
height, width, interpolation=self.interpolation) height,
width,
interpolation=self.interpolation,
dtype=tf.float32)
])) ]))
self.aspp_layers.append(pool_sequential) self.aspp_layers.append(pool_sequential)
...@@ -165,7 +168,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -165,7 +168,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
training = tf.keras.backend.learning_phase() training = tf.keras.backend.learning_phase()
result = [] result = []
for layer in self.aspp_layers: for layer in self.aspp_layers:
result.append(layer(inputs, training=training)) result.append(tf.cast(layer(inputs, training=training), inputs.dtype))
result = tf.concat(result, axis=-1) result = tf.concat(result, axis=-1)
result = self.projection(result, training=training) result = self.projection(result, training=training)
return result return result
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment