"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "1367142afd363e2799e3299b9bbf14fcb5e848c0"
Unverified commit 89be34c3, authored by NielsRogge, committed by GitHub

Fix SegformerForImageClassification (#15895)



* Fix reshape

* Apply suggestion from code review
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 130b9878
@@ -579,8 +579,11 @@ class SegformerForImageClassification(SegformerPreTrainedModel):
         sequence_output = outputs[0]

-        # reshape last hidden states to (batch_size, height*width, hidden_size)
+        # convert last hidden states to (batch_size, height*width, hidden_size)
         batch_size = sequence_output.shape[0]
+        if self.config.reshape_last_stage:
+            # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
+            sequence_output = sequence_output.permute(0, 2, 3, 1)
         sequence_output = sequence_output.reshape(batch_size, -1, self.config.hidden_sizes[-1])

         # global average pooling
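The first hunk fixes the pre-pooling reshape in SegformerForImageClassification: when config.reshape_last_stage is True, the encoder returns the last hidden state as (batch_size, num_channels, height, width), so the channels must be moved to the last axis before the spatial dimensions are flattened; otherwise the reshape mixes channel and spatial values. A minimal standalone sketch of that tensor manipulation follows (the concrete sizes are made up for illustration and are not taken from the Segformer config):

import torch

# Toy sizes for illustration only.
batch_size, hidden_size, height, width = 2, 256, 4, 4

# Shape the encoder returns for the last stage when reshape_last_stage is True:
# (batch_size, num_channels, height, width)
sequence_output = torch.randn(batch_size, hidden_size, height, width)

reshape_last_stage = True  # stands in for self.config.reshape_last_stage
if reshape_last_stage:
    # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
    sequence_output = sequence_output.permute(0, 2, 3, 1)

# Flatten the spatial dimensions: (batch_size, height*width, hidden_size)
sequence_output = sequence_output.reshape(batch_size, -1, hidden_size)

# Global average pooling over the spatial positions, as the classification head does next
pooled = sequence_output.mean(dim=1)
print(pooled.shape)  # torch.Size([2, 256])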
@@ -660,10 +663,19 @@ class SegformerDecodeHead(SegformerPreTrainedModel):
         self.dropout = nn.Dropout(config.classifier_dropout_prob)
         self.classifier = nn.Conv2d(config.decoder_hidden_size, config.num_labels, kernel_size=1)

+        self.config = config
+
     def forward(self, encoder_hidden_states):
-        batch_size, _, _, _ = encoder_hidden_states[-1].shape
+        batch_size = encoder_hidden_states[-1].shape[0]

         all_hidden_states = ()
         for encoder_hidden_state, mlp in zip(encoder_hidden_states, self.linear_c):
+            if self.config.reshape_last_stage is False and encoder_hidden_state.ndim == 3:
+                height = width = int(math.sqrt(encoder_hidden_state.shape[-1]))
+                encoder_hidden_state = (
+                    encoder_hidden_state.reshape(batch_size, height, width, -1).permute(0, 3, 1, 2).contiguous()
+                )
+
             # unify channel dimension
             height, width = encoder_hidden_state.shape[2], encoder_hidden_state.shape[3]
             encoder_hidden_state = mlp(encoder_hidden_state)
...
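The second hunk lets SegformerDecodeHead cope with the 3D hidden state the encoder produces for the last stage when config.reshape_last_stage is False, by restoring a 4D feature map before the per-stage MLP and the convolutional layers that follow. The sketch below illustrates the general reshape-plus-permute idea under simplifying assumptions (a square feature map and a (batch_size, seq_len, num_channels) token layout); it is not the library code itself:

import math
import torch

# Toy sizes for illustration only.
batch_size, seq_len, num_channels = 2, 16, 256

# A flattened, token-style hidden state: (batch_size, seq_len, num_channels)
hidden_state = torch.randn(batch_size, seq_len, num_channels)

# Recover the spatial layout, assuming a square feature map (seq_len = height * width)
height = width = int(math.sqrt(seq_len))
feature_map = (
    hidden_state.reshape(batch_size, height, width, num_channels)
    .permute(0, 3, 1, 2)  # channels-first, as nn.Conv2d layers expect
    .contiguous()
)
print(feature_map.shape)  # torch.Size([2, 256, 4, 4])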