Unverified Commit 62b05b69 authored by p-mishra1's avatar p-mishra1 Committed by GitHub
Browse files

Add type annotations for segformer classes (#16099)

parent 9042dfe3
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
import collections import collections
import math import math
from typing import Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -373,11 +374,11 @@ class SegformerEncoder(nn.Module): ...@@ -373,11 +374,11 @@ class SegformerEncoder(nn.Module):
def forward( def forward(
self, self,
pixel_values, pixel_values: torch.FloatTensor,
output_attentions=False, output_attentions: Optional[bool] = False,
output_hidden_states=False, output_hidden_states: Optional[bool] = False,
return_dict=True, return_dict: Optional[bool] = True,
): ) -> Union[Tuple, BaseModelOutput]:
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None all_self_attentions = () if output_attentions else None
...@@ -501,7 +502,13 @@ class SegformerModel(SegformerPreTrainedModel): ...@@ -501,7 +502,13 @@ class SegformerModel(SegformerPreTrainedModel):
modality="vision", modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE, expected_output=_EXPECTED_OUTPUT_SHAPE,
) )
def forward(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None): def forward(
self,
pixel_values: torch.FloatTensor,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutput]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
...@@ -556,12 +563,12 @@ class SegformerForImageClassification(SegformerPreTrainedModel): ...@@ -556,12 +563,12 @@ class SegformerForImageClassification(SegformerPreTrainedModel):
) )
def forward( def forward(
self, self,
pixel_values=None, pixel_values: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, SequenceClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ..., Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
...@@ -715,12 +722,12 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel): ...@@ -715,12 +722,12 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
@replace_return_docstrings(output_type=SemanticSegmentationModelOutput, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=SemanticSegmentationModelOutput, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
pixel_values, pixel_values: torch.FloatTensor,
labels=None, labels: Optional[torch.LongTensor] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, SemanticSegmentationModelOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ..., Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment