Unverified Commit a1c4b630 authored by Matt's avatar Matt Committed by GitHub
Browse files

TF CI fix for Segformer (#24426)

Fix segformer so compilation can figure out the channel dim
parent 754f61ca
......@@ -710,21 +710,20 @@ class TFSegformerDecodeHead(TFSegformerPreTrainedModel):
self.config = config
def call(self, encoder_hidden_states, training: bool = False):
batch_size = shape_list(encoder_hidden_states[-1])[0]
all_hidden_states = ()
for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.mlps):
if self.config.reshape_last_stage is False and len(shape_list(encoder_hidden_state)) == 3:
height = tf.math.sqrt(tf.cast(shape_list(encoder_hidden_state)[1], tf.float32))
height = width = tf.cast(height, tf.int32)
encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1))
channel_dim = shape_list(encoder_hidden_state)[-1]
encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim))
# unify channel dimension
encoder_hidden_state = tf.transpose(encoder_hidden_state, perm=[0, 2, 3, 1])
height = shape_list(encoder_hidden_state)[1]
width = shape_list(encoder_hidden_state)[2]
height, width = shape_list(encoder_hidden_state)[1:3]
encoder_hidden_state = mlp(encoder_hidden_state)
encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1))
channel_dim = shape_list(encoder_hidden_state)[-1]
encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim))
# upsample
temp_state = tf.transpose(encoder_hidden_states[0], perm=[0, 2, 3, 1])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment