Commit 8124922d authored by Gunho Park's avatar Gunho Park
Browse files

Internal change

parent 06000e73
......@@ -102,14 +102,17 @@ class BASNet_Decoder(tf.keras.Model):
# Get input feature pyramid from backbone.
inputs = self._build_input_pyramid(input_specs)
levels = sorted(inputs.keys(), reverse=True)
sup = {}
for i, spec in enumerate(BASNET_DECODER_SPECS):
if i == 0:
x = inputs['5'] # Bridge input
#x = inputs['5'] # Bridge input
x = inputs[levels[0]] # Bridge input
# str(levels[-1]) ??
else:
x = tf.keras.layers.Concatenate(axis=-1)([x, inputs[str(6-i)]])
x = tf.keras.layers.Concatenate(axis=-1)([x, inputs[levels[i-1]]])
for j in range(3):
x = nn_blocks.ConvBlock(
......@@ -141,7 +144,7 @@ class BASNet_Decoder(tf.keras.Model):
output = tf.keras.layers.Activation(
activation='sigmoid'
)(output)
sup[str(i+1)] = output
sup[str(i)] = output
if i != 0:
x = tf.keras.layers.UpSampling2D(
size=2,
......@@ -150,7 +153,7 @@ class BASNet_Decoder(tf.keras.Model):
self._output_specs = {
str(order): sup[str(order)].get_shape()
for order in range(1, 7)
for order in range(0, len(BASNET_DECODER_SPECS))
}
super(BASNet_Decoder, self).__init__(inputs=inputs, outputs=sup, **kwargs)
......
......@@ -56,8 +56,11 @@ class BASNetModel(tf.keras.Model):
if self.decoder:
features = self.decoder(features)
levels = sorted(features.keys())
new_key = str(len(levels))
if self.refinement:
features['ref'] = self.refinement(features['7'])
features[new_key] = self.refinement(features[levels[-1]])
return features
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment