Commit 8124922d authored by Gunho Park's avatar Gunho Park
Browse files

Internal change

parent 06000e73
...@@ -102,14 +102,17 @@ class BASNet_Decoder(tf.keras.Model): ...@@ -102,14 +102,17 @@ class BASNet_Decoder(tf.keras.Model):
# Get input feature pyramid from backbone. # Get input feature pyramid from backbone.
inputs = self._build_input_pyramid(input_specs) inputs = self._build_input_pyramid(input_specs)
levels = sorted(inputs.keys(), reverse=True)
sup = {} sup = {}
for i, spec in enumerate(BASNET_DECODER_SPECS): for i, spec in enumerate(BASNET_DECODER_SPECS):
if i == 0: if i == 0:
x = inputs['5'] # Bridge input #x = inputs['5'] # Bridge input
x = inputs[levels[0]] # Bridge input
# str(levels[-1]) ??
else: else:
x = tf.keras.layers.Concatenate(axis=-1)([x, inputs[str(6-i)]]) x = tf.keras.layers.Concatenate(axis=-1)([x, inputs[levels[i-1]]])
for j in range(3): for j in range(3):
x = nn_blocks.ConvBlock( x = nn_blocks.ConvBlock(
...@@ -141,7 +144,7 @@ class BASNet_Decoder(tf.keras.Model): ...@@ -141,7 +144,7 @@ class BASNet_Decoder(tf.keras.Model):
output = tf.keras.layers.Activation( output = tf.keras.layers.Activation(
activation='sigmoid' activation='sigmoid'
)(output) )(output)
sup[str(i+1)] = output sup[str(i)] = output
if i != 0: if i != 0:
x = tf.keras.layers.UpSampling2D( x = tf.keras.layers.UpSampling2D(
size=2, size=2,
...@@ -150,7 +153,7 @@ class BASNet_Decoder(tf.keras.Model): ...@@ -150,7 +153,7 @@ class BASNet_Decoder(tf.keras.Model):
self._output_specs = { self._output_specs = {
str(order): sup[str(order)].get_shape() str(order): sup[str(order)].get_shape()
for order in range(1, 7) for order in range(0, len(BASNET_DECODER_SPECS))
} }
super(BASNet_Decoder, self).__init__(inputs=inputs, outputs=sup, **kwargs) super(BASNet_Decoder, self).__init__(inputs=inputs, outputs=sup, **kwargs)
......
...@@ -56,8 +56,11 @@ class BASNetModel(tf.keras.Model): ...@@ -56,8 +56,11 @@ class BASNetModel(tf.keras.Model):
if self.decoder: if self.decoder:
features = self.decoder(features) features = self.decoder(features)
levels = sorted(features.keys())
new_key = str(len(levels))
if self.refinement: if self.refinement:
features['ref'] = self.refinement(features['7']) features[new_key] = self.refinement(features[levels[-1]])
return features return features
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment