"...git@developer.sourcefind.cn:modelzoo/qwen_lmdeploy.git" did not exist on "fe46dac2c2ea1a988929fba05e9d3d3c9b11dfd7"
Commit e5d7e4ff authored by Jiageng Zhang's avatar Jiageng Zhang Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 479358049
parent 3d29d12b
...@@ -87,8 +87,6 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -87,8 +87,6 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self.use_depthwise_convolution = use_depthwise_convolution self.use_depthwise_convolution = use_depthwise_convolution
def build(self, input_shape): def build(self, input_shape):
height = input_shape[1]
width = input_shape[2]
channels = input_shape[3] channels = input_shape[3]
self.aspp_layers = [] self.aspp_layers = []
...@@ -171,12 +169,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -171,12 +169,7 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
axis=bn_axis, axis=bn_axis,
momentum=self.batchnorm_momentum, momentum=self.batchnorm_momentum,
epsilon=self.batchnorm_epsilon), epsilon=self.batchnorm_epsilon),
tf.keras.layers.Activation(self.activation), tf.keras.layers.Activation(self.activation)
tf.keras.layers.experimental.preprocessing.Resizing(
height,
width,
interpolation=self.interpolation,
dtype=tf.float32)
])) ]))
self.aspp_layers.append(pool_sequential) self.aspp_layers.append(pool_sequential)
...@@ -201,8 +194,12 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -201,8 +194,12 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
if training is None: if training is None:
training = tf.keras.backend.learning_phase() training = tf.keras.backend.learning_phase()
result = [] result = []
for layer in self.aspp_layers: for i, layer in enumerate(self.aspp_layers):
result.append(tf.cast(layer(inputs, training=training), inputs.dtype)) result.append(tf.cast(layer(inputs, training=training), inputs.dtype))
if i == len(self.aspp_layers) - 1:
input_shape = inputs.get_shape().as_list()
height, width = input_shape[1:3]
result[-1] = tf.image.resize(result[-1], [height, width])
result = tf.concat(result, axis=-1) result = tf.concat(result, axis=-1)
result = self.projection(result, training=training) result = self.projection(result, training=training)
return result return result
......
...@@ -1146,8 +1146,6 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -1146,8 +1146,6 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self._bn_axis = 1 self._bn_axis = 1
def build(self, input_shape): def build(self, input_shape):
height = input_shape[1]
width = input_shape[2]
channels = input_shape[3] channels = input_shape[3]
self.aspp_layers = [] self.aspp_layers = []
...@@ -1220,9 +1218,6 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -1220,9 +1218,6 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
self.aspp_layers.append(pooling + [conv2, norm2]) self.aspp_layers.append(pooling + [conv2, norm2])
self._resizing_layer = tf.keras.layers.Resizing(
height, width, interpolation=self._interpolation, dtype=tf.float32)
self._projection = [ self._projection = [
tf.keras.layers.Conv2D( tf.keras.layers.Conv2D(
filters=self._output_channels, filters=self._output_channels,
...@@ -1254,7 +1249,9 @@ class SpatialPyramidPooling(tf.keras.layers.Layer): ...@@ -1254,7 +1249,9 @@ class SpatialPyramidPooling(tf.keras.layers.Layer):
# Apply resize layer to the end of the last set of layers. # Apply resize layer to the end of the last set of layers.
if i == len(self.aspp_layers) - 1: if i == len(self.aspp_layers) - 1:
x = self._resizing_layer(x) input_shape = inputs.get_shape().as_list()
height, width = input_shape[1:3]
x = tf.image.resize(x, [height, width], self._interpolation)
result.append(tf.cast(x, inputs.dtype)) result.append(tf.cast(x, inputs.dtype))
x = self._concat_layer(result) x = self._concat_layer(result)
......
...@@ -308,7 +308,6 @@ class SemanticSegmentationTask(base_task.Task): ...@@ -308,7 +308,6 @@ class SemanticSegmentationTask(base_task.Task):
self.iou_metric.update_state(labels, outputs['logits']) self.iou_metric.update_state(labels, outputs['logits'])
if metrics: if metrics:
self.process_metrics(metrics, labels, outputs) self.process_metrics(metrics, labels, outputs)
logs.update({m.name: m.result() for m in metrics})
return logs return logs
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment