"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "42e00cf9e1969973a563db2900ed86bbf58dbc71"
Unverified commit 35ecf99c, authored by Sylvain Gugger, committed by GitHub

Revert changes in logit size for semantic segmentation models (#15722)

* Revert changes in logit size for semantic segmentation models

* Address review comments
parent d1fcc90a
@@ -822,8 +822,17 @@ class SemanticSegmentationModelOutput(ModelOutput):
     Args:
         loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
             Classification (or regression if config.num_labels==1) loss.
-        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, height, width)`):
+        logits (`torch.FloatTensor` of shape `(batch_size, config.num_labels, logits_height, logits_width)`):
             Classification scores for each pixel.
+
+        <Tip warning={true}>
+
+        The logits returned do not necessarily have the same size as the `pixel_values` passed as inputs. This is
+        to avoid doing two interpolations and losing some quality when a user needs to resize the logits to the
+        original image size as post-processing. You should always check your logits shape and resize as needed.
+
+        </Tip>
+
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
             shape `(batch_size, patch_size, hidden_size)`.
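As the new tip tells users to do, the logits can be resized in a single step as post-processing. A minimal sketch with dummy tensors (the shapes mirror the BEiT integration test further below; `original_size` is an assumption for the sketch, not part of this diff):

```python
import torch
import torch.nn.functional as F

# Dummy logits standing in for a model output: (batch_size, num_labels, h/4, w/4).
logits = torch.randn(1, 150, 160, 160)
original_size = (640, 640)  # original image size, assumed for this sketch

# A single interpolation straight to the original size avoids the double
# resize the tip warns about; argmax then gives the per-pixel class map.
upsampled = F.interpolate(logits, size=original_size, mode="bilinear", align_corners=False)
segmentation_map = upsampled.argmax(dim=1)  # (1, 640, 640)
print(segmentation_map.shape)
```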
@@ -93,10 +93,6 @@ class BeitConfig(PretrainedConfig):
             Whether to concatenate the output of the auxiliary head with the input before the classification layer.
         semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
             The index that is ignored by the loss function of the semantic segmentation model.
-        legacy_output (`bool`, *optional*, defaults to `False`):
-            Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`)
-            This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.
-
     Example:
@@ -145,7 +141,6 @@ class BeitConfig(PretrainedConfig):
         auxiliary_num_convs=1,
         auxiliary_concat_input=False,
         semantic_loss_ignore_index=255,
-        legacy_output=False,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -181,4 +176,3 @@ class BeitConfig(PretrainedConfig):
         self.auxiliary_num_convs = auxiliary_num_convs
         self.auxiliary_concat_input = auxiliary_concat_input
         self.semantic_loss_ignore_index = semantic_loss_ignore_index
-        self.legacy_output = legacy_output
@@ -17,7 +17,6 @@
 import collections.abc
 import math
-import warnings
 from dataclasses import dataclass

 import torch
@@ -1121,8 +1120,11 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

-    def compute_loss(self, upsampled_logits, auxiliary_logits, labels):
+    def compute_loss(self, logits, auxiliary_logits, labels):
         # upsample logits to the images' original size
+        upsampled_logits = nn.functional.interpolate(
+            logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+        )
         if auxiliary_logits is not None:
             upsampled_auxiliary_logits = nn.functional.interpolate(
                 auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
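The pattern after this change: the logits kept in the output stay at the decode head's resolution, and upsampling happens only inside the loss computation, driven by the labels' spatial size. A self-contained sketch of that pattern with toy shapes (not the model's actual tensors):

```python
import torch
from torch import nn

# Toy tensors: logits at 1/4 resolution, labels at full resolution.
logits = torch.randn(2, 150, 40, 40)
labels = torch.randint(0, 150, (2, 160, 160))

# Upsample only for the loss; the returned logits stay small.
upsampled_logits = nn.functional.interpolate(
    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
)
# 255 is the default semantic_loss_ignore_index in these configs.
loss = nn.CrossEntropyLoss(ignore_index=255)(upsampled_logits, labels)
print(loss)
```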
@@ -1145,17 +1147,11 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
-        legacy_output=None,
     ):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
             Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
-        legacy_output (`bool`, *optional*):
-            Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`). Will default
-            to `self.config.legacy_output`.
-
-            This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.

         Returns:
@@ -1181,14 +1177,6 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        legacy_output = legacy_output if legacy_output is not None else self.config.legacy_output
-        if not legacy_output:
-            warnings.warn(
-                "The output of this model has changed in v4.17.0 and the logits now have the same size as the inputs. "
-                "You can activate the previous behavior by passing `legacy_output=True` to this call or the "
-                "configuration of this model (only until v5, then that argument will be removed).",
-                FutureWarning,
-            )

         outputs = self.beit(
             pixel_values,
@@ -1216,10 +1204,6 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
         logits = self.decode_head(features)

-        upsampled_logits = nn.functional.interpolate(
-            logits, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False
-        )
-
         auxiliary_logits = None
         if self.auxiliary_head is not None:
             auxiliary_logits = self.auxiliary_head(features)
@@ -1229,26 +1213,18 @@ class BeitForSemanticSegmentation(BeitPreTrainedModel):
             if self.config.num_labels == 1:
                 raise ValueError("The number of labels should be greater than one")
             else:
-                loss = self.compute_loss(upsampled_logits, auxiliary_logits, labels)
+                loss = self.compute_loss(logits, auxiliary_logits, labels)

         if not return_dict:
             if output_hidden_states:
-                output = (logits if legacy_output else upsampled_logits,) + outputs[2:]
+                output = (logits,) + outputs[2:]
             else:
-                output = (logits if legacy_output else upsampled_logits,) + outputs[3:]
+                output = (logits,) + outputs[3:]
             return ((loss,) + output) if loss is not None else output

-        if legacy_output:
-            return SequenceClassifierOutput(
-                loss=loss,
-                logits=logits,
-                hidden_states=outputs.hidden_states if output_hidden_states else None,
-                attentions=outputs.attentions,
-            )
-        else:
-            return SemanticSegmentationModelOutput(
-                loss=loss,
-                logits=upsampled_logits,
-                hidden_states=outputs.hidden_states if output_hidden_states else None,
-                attentions=outputs.attentions,
-            )
+        return SemanticSegmentationModelOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            attentions=outputs.attentions,
+        )
@@ -83,10 +83,6 @@ class SegformerConfig(PretrainedConfig):
             required for the semantic segmentation model.
         semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
             The index that is ignored by the loss function of the semantic segmentation model.
-        legacy_output (`bool`, *optional*, defaults to `False`):
-            Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`)
-            This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.
-
     Example:
@@ -128,7 +124,6 @@ class SegformerConfig(PretrainedConfig):
         is_encoder_decoder=False,
         reshape_last_stage=True,
         semantic_loss_ignore_index=255,
-        legacy_output=False,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -154,4 +149,3 @@ class SegformerConfig(PretrainedConfig):
         self.decoder_hidden_size = decoder_hidden_size
         self.reshape_last_stage = reshape_last_stage
         self.semantic_loss_ignore_index = semantic_loss_ignore_index
-        self.legacy_output = legacy_output
@@ -17,7 +17,6 @@
 import collections
 import math
-import warnings

 import torch
 import torch.utils.checkpoint
@@ -697,17 +696,11 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
         output_attentions=None,
         output_hidden_states=None,
         return_dict=None,
-        legacy_output=None,
     ):
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
             Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
             config.num_labels - 1]`. If `config.num_labels > 1`, a classification loss is computed (Cross-Entropy).
-        legacy_output (`bool`, *optional*):
-            Whether to return the legacy outputs or not (with logits of shape `height / 4 , width / 4`). Will default
-            to `self.config.legacy_output`.
-
-            This argument is only present for backward compatibility reasons and will be removed in v5 of Transformers.

         Returns:
@@ -732,14 +725,6 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
-        legacy_output = legacy_output if legacy_output is not None else self.config.legacy_output
-        if not legacy_output:
-            warnings.warn(
-                "The output of this model has changed in v4.17.0 and the logits now have the same size as the inputs. "
-                "You can activate the previous behavior by passing `legacy_output=True` to this call or the "
-                "configuration of this model (only until v5, then that argument will be removed).",
-                FutureWarning,
-            )

         outputs = self.segformer(
             pixel_values,
@@ -752,37 +737,28 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
         logits = self.decode_head(encoder_hidden_states)

-        upsampled_logits = nn.functional.interpolate(
-            logits, size=pixel_values.shape[-2:], mode="bilinear", align_corners=False
-        )
-
         loss = None
         if labels is not None:
             if self.config.num_labels == 1:
                 raise ValueError("The number of labels should be greater than one")
             else:
                 # upsample logits to the images' original size
+                upsampled_logits = nn.functional.interpolate(
+                    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
+                )
                 loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
                 loss = loss_fct(upsampled_logits, labels)

         if not return_dict:
             if output_hidden_states:
-                output = (logits if legacy_output else upsampled_logits,) + outputs[1:]
+                output = (logits,) + outputs[1:]
             else:
-                output = (logits if legacy_output else upsampled_logits,) + outputs[2:]
+                output = (logits,) + outputs[2:]
             return ((loss,) + output) if loss is not None else output

-        if legacy_output:
-            return SequenceClassifierOutput(
-                loss=loss,
-                logits=logits,
-                hidden_states=outputs.hidden_states if output_hidden_states else None,
-                attentions=outputs.attentions,
-            )
-        else:
-            return SemanticSegmentationModelOutput(
-                loss=loss,
-                logits=upsampled_logits,
-                hidden_states=outputs.hidden_states if output_hidden_states else None,
-                attentions=outputs.attentions,
-            )
+        return SemanticSegmentationModelOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states if output_hidden_states else None,
+            attentions=outputs.attentions,
+        )
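End to end, the reverted behavior looks like this from a SegFormer caller's perspective. A sketch assuming the publicly available `nvidia/segformer-b0-finetuned-ade-512-512` checkpoint and the standard COCO demo image (both assumptions, not part of this diff); the printed shape matches the updated SegFormer integration tests below:

```python
import requests
import torch
from PIL import Image
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation

checkpoint = "nvidia/segformer-b0-finetuned-ade-512-512"
feature_extractor = SegformerFeatureExtractor.from_pretrained(checkpoint)
model = SegformerForSemanticSegmentation.from_pretrained(checkpoint)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = feature_extractor(images=image, return_tensors="pt")  # resizes to 512x512

with torch.no_grad():
    outputs = model(**inputs)

# Logits come back at 1/4 of the input resolution, not at 512x512:
print(outputs.logits.shape)  # torch.Size([1, 150, 128, 128])
```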
@@ -162,11 +162,11 @@ class BeitModelTester:
         model.eval()
         result = model(pixel_values)
         self.parent.assertEqual(
-            result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
+            result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
         )
         result = model(pixel_values, labels=pixel_labels)
         self.parent.assertEqual(
-            result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
+            result.logits.shape, (self.batch_size, self.num_labels, self.image_size * 2, self.image_size * 2)
         )

     def prepare_config_and_inputs_for_common(self):
@@ -533,14 +533,14 @@ class BeitModelIntegrationTest(unittest.TestCase):
         logits = outputs.logits

         # verify the logits
-        expected_shape = torch.Size((1, 150, 640, 640))
+        expected_shape = torch.Size((1, 150, 160, 160))
         self.assertEqual(logits.shape, expected_shape)

         expected_slice = torch.tensor(
             [
-                [[-4.9225, -4.9225, -4.6066], [-4.9225, -4.9225, -4.6066], [-4.6675, -4.6675, -4.3617]],
-                [[-5.8168, -5.8168, -5.5163], [-5.8168, -5.8168, -5.5163], [-5.5728, -5.5728, -5.2842]],
-                [[-0.0078, -0.0078, 0.4926], [-0.0078, -0.0078, 0.4926], [0.3664, 0.3664, 0.8309]],
+                [[-4.9225, -2.3954, -3.0522], [-2.8822, -1.0046, -1.7561], [-2.9549, -1.3228, -2.1347]],
+                [[-5.8168, -3.4129, -4.0778], [-3.8651, -2.2214, -3.0277], [-3.8356, -2.4643, -3.3535]],
+                [[-0.0078, 3.9952, 4.0754], [2.9856, 4.6944, 5.0035], [3.2413, 4.7813, 4.9969]],
             ]
         ).to(torch_device)
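A side note on why the old expected slices above contained repeated values: bilinearly upsampled logits are smooth interpolations of a much smaller map, so adjacent pixels at the output resolution were (near-)duplicates, whereas slices of the raw decode-head logits vary freely. A tiny illustration:

```python
import torch
import torch.nn.functional as F

small = torch.arange(4.0).reshape(1, 1, 2, 2)  # a toy 2x2 "logit map"
big = F.interpolate(small, size=(8, 8), mode="bilinear", align_corners=False)
# Entries of `big` near each corner repeat the nearest source value, which is
# why the old 640x640 expected slices showed runs of identical numbers.
print(big[0, 0, :3, :3])
```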
@@ -135,11 +135,11 @@ class SegformerModelTester:
         model.eval()
         result = model(pixel_values)
         self.parent.assertEqual(
-            result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
+            result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
         )
         result = model(pixel_values, labels=labels)
         self.parent.assertEqual(
-            result.logits.shape, (self.batch_size, self.num_labels, self.image_size, self.image_size)
+            result.logits.shape, (self.batch_size, self.num_labels, self.image_size // 4, self.image_size // 4)
         )

     def prepare_config_and_inputs_for_common(self):
@@ -363,14 +363,14 @@ class SegformerModelIntegrationTest(unittest.TestCase):
         with torch.no_grad():
             outputs = model(pixel_values)

-        expected_shape = torch.Size((1, model.config.num_labels, 512, 512))
+        expected_shape = torch.Size((1, model.config.num_labels, 128, 128))
         self.assertEqual(outputs.logits.shape, expected_shape)

         expected_slice = torch.tensor(
             [
-                [[-4.6309, -4.6309, -4.7425], [-4.6309, -4.6309, -4.7425], [-4.7011, -4.7011, -4.8136]],
-                [[-12.1391, -12.1391, -12.2858], [-12.1391, -12.1391, -12.2858], [-12.2309, -12.2309, -12.3758]],
-                [[-12.5134, -12.5134, -12.6328], [-12.5134, -12.5134, -12.6328], [-12.5576, -12.5576, -12.6865]],
+                [[-4.6310, -5.5232, -6.2356], [-5.1921, -6.1444, -6.5996], [-5.4424, -6.2790, -6.7574]],
+                [[-12.1391, -13.3122, -13.9554], [-12.8732, -13.9352, -14.3563], [-12.9438, -13.8226, -14.2513]],
+                [[-12.5134, -13.4686, -14.4915], [-12.8669, -14.4343, -14.7758], [-13.2523, -14.5819, -15.0694]],
             ]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-4))
@@ -392,14 +392,14 @@ class SegformerModelIntegrationTest(unittest.TestCase):
         with torch.no_grad():
             outputs = model(pixel_values)

-        expected_shape = torch.Size((1, model.config.num_labels, 512, 512))
+        expected_shape = torch.Size((1, model.config.num_labels, 128, 128))
         self.assertEqual(outputs.logits.shape, expected_shape)

         expected_slice = torch.tensor(
             [
-                [[-13.5729, -13.5729, -13.6149], [-13.5729, -13.5729, -13.6149], [-13.6697, -13.6697, -13.7224]],
-                [[-17.1638, -17.1638, -17.0022], [-17.1638, -17.1638, -17.0022], [-17.1754, -17.1754, -17.0358]],
-                [[-3.6452, -3.6452, -3.5670], [-3.6452, -3.6452, -3.5670], [-3.5744, -3.5744, -3.5079]],
+                [[-13.5748, -13.9111, -12.6500], [-14.3500, -15.3683, -14.2328], [-14.7532, -16.0424, -15.6087]],
+                [[-17.1651, -15.8725, -12.9653], [-17.2580, -17.3718, -14.8223], [-16.6058, -16.8783, -16.7452]],
+                [[-3.6456, -3.0209, -1.4203], [-3.0797, -3.1959, -2.0000], [-1.8757, -1.9217, -1.6997]],
             ]
         ).to(torch_device)
         self.assertTrue(torch.allclose(outputs.logits[0, :3, :3, :3], expected_slice, atol=1e-1))