"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c89180a9de1fc2e98654812fd1c233c3bc6a8d43"
Unverified Commit ea07064a authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

Returning outputs only when asked for, for MaskFormer. (#15936)

* Returning outputs only when asked for, for MaskFormer.

* Adding `output_auxiliary_logits` to the config.
parent b19f3e69
......@@ -69,6 +69,8 @@ class MaskFormerConfig(PretrainedConfig):
The weight for the cross entropy loss.
mask_weight (`float`, *optional*, defaults to 20.0):
The weight for the mask loss.
output_auxiliary_logits (`bool`, *optional*):
Whether or not the model should output its `auxiliary_logits`.
Raises:
`ValueError`:
......@@ -109,6 +111,7 @@ class MaskFormerConfig(PretrainedConfig):
dice_weight: float = 1.0,
cross_entropy_weight: float = 1.0,
mask_weight: float = 20.0,
output_auxiliary_logits: Optional[bool] = None,
**kwargs,
):
if backbone_config is None:
......@@ -156,6 +159,7 @@ class MaskFormerConfig(PretrainedConfig):
self.mask_weight = mask_weight
self.use_auxiliary_loss = use_auxiliary_loss
self.no_object_weight = no_object_weight
self.output_auxiliary_logits = output_auxiliary_logits
self.num_attention_heads = self.decoder_config.encoder_attention_heads
self.num_hidden_layers = self.decoder_config.num_hidden_layers
......
......@@ -2313,9 +2313,16 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
)
queries = transformer_module_output.last_hidden_state
encoder_hidden_states = pixel_level_module_output.encoder_hidden_states if output_hidden_states else ()
pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states if output_hidden_states else ()
transformer_decoder_hidden_states = transformer_module_output.hidden_states if output_hidden_states else ()
if output_hidden_states:
encoder_hidden_states = pixel_level_module_output.encoder_hidden_states
pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states
transformer_decoder_hidden_states = transformer_module_output.hidden_states
hidden_states = encoder_hidden_states + pixel_decoder_hidden_states + transformer_decoder_hidden_states
else:
encoder_hidden_states = None
pixel_decoder_hidden_states = None
transformer_decoder_hidden_states = None
hidden_states = None
output = MaskFormerModelOutput(
encoder_last_hidden_state=image_features,
......@@ -2324,7 +2331,7 @@ class MaskFormerModel(MaskFormerPreTrainedModel):
encoder_hidden_states=encoder_hidden_states,
pixel_decoder_hidden_states=pixel_decoder_hidden_states,
transformer_decoder_hidden_states=transformer_decoder_hidden_states,
hidden_states=encoder_hidden_states + pixel_decoder_hidden_states + transformer_decoder_hidden_states,
hidden_states=hidden_states,
attentions=transformer_module_output.attentions,
)
......@@ -2421,6 +2428,7 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
mask_labels: Optional[Tensor] = None,
class_labels: Optional[Tensor] = None,
pixel_mask: Optional[Tensor] = None,
output_auxiliary_logits: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_attentions: Optional[bool] = None,
return_dict: Optional[bool] = None,
......@@ -2484,6 +2492,12 @@ class MaskFormerForInstanceSegmentation(MaskFormerPreTrainedModel):
)
loss = self.get_loss(loss_dict)
output_auxiliary_logits = (
self.config.output_auxiliary_logits if output_auxiliary_logits is None else output_auxiliary_logits
)
if not output_auxiliary_logits:
auxiliary_logits = None
output = MaskFormerForInstanceSegmentationOutput(
loss=loss,
**outputs,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment