"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "36f183ebab1261b388739d628aaa0b4150068df0"
Unverified commit a1c4b630, authored by Matt, committed by GitHub

TF CI fix for Segformer (#24426)

Fix segformer so compilation can figure out the channel dim
parent 754f61ca
```diff
...
@@ -710,21 +710,20 @@ class TFSegformerDecodeHead(TFSegformerPreTrainedModel):
         self.config = config

     def call(self, encoder_hidden_states, training: bool = False):
-        batch_size = shape_list(encoder_hidden_states[-1])[0]
         all_hidden_states = ()
         for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.mlps):
             if self.config.reshape_last_stage is False and len(shape_list(encoder_hidden_state)) == 3:
                 height = tf.math.sqrt(tf.cast(shape_list(encoder_hidden_state)[1], tf.float32))
                 height = width = tf.cast(height, tf.int32)
-                encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1))
+                channel_dim = shape_list(encoder_hidden_state)[-1]
+                encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim))

             # unify channel dimension
             encoder_hidden_state = tf.transpose(encoder_hidden_state, perm=[0, 2, 3, 1])
-            height = shape_list(encoder_hidden_state)[1]
-            width = shape_list(encoder_hidden_state)[2]
+            height, width = shape_list(encoder_hidden_state)[1:3]
             encoder_hidden_state = mlp(encoder_hidden_state)
-            encoder_hidden_state = tf.reshape(encoder_hidden_state, (batch_size, height, width, -1))
+            channel_dim = shape_list(encoder_hidden_state)[-1]
+            encoder_hidden_state = tf.reshape(encoder_hidden_state, (-1, height, width, channel_dim))

             # upsample
             temp_state = tf.transpose(encoder_hidden_states[0], perm=[0, 2, 3, 1])
...
```
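
For context on why the reshape changed (this sketch is not part of the commit, and `ReshapeThenConv` is an illustrative, made-up layer, not code from the repo): when the decode head is traced for compilation, `tf.reshape(x, (batch_size, height, width, -1))` with a dynamic `batch_size` leaves the channel dimension unknown, and Keras layers such as `Conv2D` cannot be built against an undefined channel axis. Taking the statically known channel size from the tensor's shape and folding the unknown batch size into `-1` keeps that axis defined. A minimal sketch of the same pattern, assuming a square feature map:

```python
import tensorflow as tf


class ReshapeThenConv(tf.keras.layers.Layer):
    """Hypothetical layer mirroring the reshape-then-conv pattern in the decode head."""

    def __init__(self, use_static_channel_dim: bool, **kwargs):
        super().__init__(**kwargs)
        self.use_static_channel_dim = use_static_channel_dim
        self.conv = tf.keras.layers.Conv2D(filters=8, kernel_size=1)

    def call(self, hidden_state):
        # hidden_state: (batch, seq_len, channels); assume a square feature map,
        # so height == width == sqrt(seq_len).
        height = tf.cast(tf.math.sqrt(tf.cast(tf.shape(hidden_state)[1], tf.float32)), tf.int32)
        if self.use_static_channel_dim:
            # Pattern after the fix: take the channel size from the *static* shape and
            # fold the unknown batch size into -1, so the last axis stays defined.
            channel_dim = hidden_state.shape[-1]
            hidden_state = tf.reshape(hidden_state, (-1, height, height, channel_dim))
        else:
            # Pattern before the fix: a dynamic batch size plus -1 for channels leaves
            # every output dimension unknown, and Conv2D cannot build without a channel dim.
            batch_size = tf.shape(hidden_state)[0]
            hidden_state = tf.reshape(hidden_state, (batch_size, height, height, -1))
        return self.conv(hidden_state)


# Building a functional model with an unspecified batch size: only the
# static-channel variant lets Conv2D infer its input channels.
inputs = tf.keras.Input(shape=(16 * 16, 32))  # (seq_len, channels), batch is None
outputs = ReshapeThenConv(use_static_channel_dim=True)(inputs)
model = tf.keras.Model(inputs, outputs)
print(model.output_shape)  # (None, None, None, 8) -- the channel dim stayed static
```

The design choice matches the diff above: only the batch (and, here, the spatial) dimensions stay dynamic, while the channel dimension is pinned from the tensor's static shape so downstream layers can build during compilation.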