Unverified commit 89be34c3, authored by NielsRogge and committed by GitHub
Browse files

Fix SegformerForImageClassification (#15895)



* Fix reshape

* Apply suggestion from code review
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 130b9878
......@@ -579,8 +579,11 @@ class SegformerForImageClassification(SegformerPreTrainedModel):
sequence_output = outputs[0]
# reshape last hidden states to (batch_size, height*width, hidden_size)
# convert last hidden states to (batch_size, height*width, hidden_size)
batch_size = sequence_output.shape[0]
if self.config.reshape_last_stage:
# (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
sequence_output = sequence_output.permute(0, 2, 3, 1)
sequence_output = sequence_output.reshape(batch_size, -1, self.config.hidden_sizes[-1])
# global average pooling
......@@ -660,10 +663,19 @@ class SegformerDecodeHead(SegformerPreTrainedModel):
self.dropout = nn.Dropout(config.classifier_dropout_prob)
self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)
self.config = config
def forward(self, encoder_hidden_states):
batch_size, _, _, _ = encoder_hidden_states[-1].shape
batch_size = encoder_hidden_states[-1].shape[0]
all_hidden_states = ()
for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c):
if self.config.reshape_last_stage is False and encoder_hidden_state.ndim == 3:
height = width = int(math.sqrt(encoder_hidden_state.shape[-1]))
encoder_hidden_state = (
encoder_hidden_state.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
)
# unify channel dimension
height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
encoder_hidden_state = mlp(encoder_hidden_state)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment