Unverified Commit 62b05b69 authored by p-mishra1, committed by GitHub

Add type annotations for segformer classes (#16099)

parent 9042dfe3
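As a quick illustration (not part of the commit), the sketch below shows how the annotated `forward` signature reads at a call site once these type hints are in place. The randomly initialized config, input shape, and flag values are assumptions chosen only for the example.

```python
from typing import Tuple, Union

import torch
from transformers import SegformerConfig, SegformerModel
from transformers.modeling_outputs import BaseModelOutput

# Randomly initialized model; weights and input are illustrative only.
config = SegformerConfig()
model = SegformerModel(config)
model.eval()

pixel_values = torch.randn(1, 3, 512, 512)  # (batch, channels, height, width)

# With return_dict=True the annotated return type Union[Tuple, BaseModelOutput]
# resolves to a BaseModelOutput instance at this call site.
with torch.no_grad():
    outputs: Union[Tuple, BaseModelOutput] = model(
        pixel_values,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    )

print(outputs.last_hidden_state.shape)
```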
@@ -17,6 +17,7 @@
 import collections
 import math
+from typing import Optional, Tuple, Union
 import torch
 import torch.utils.checkpoint
@@ -373,11 +374,11 @@ class SegformerEncoder(nn.Module):
     def forward(
         self,
-        pixel_values,
-        output_attentions=False,
-        output_hidden_states=False,
-        return_dict=True,
-    ):
+        pixel_values: torch.FloatTensor,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+    ) -> Union[Tuple, BaseModelOutput]:
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None
@@ -501,7 +502,13 @@ class SegformerModel(SegformerPreTrainedModel):
         modality="vision",
         expected_output=_EXPECTED_OUTPUT_SHAPE,
     )
-    def forward(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -556,12 +563,12 @@ class SegformerForImageClassification(SegformerPreTrainedModel):
     )
     def forward(
         self,
-        pixel_values=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
@@ -715,12 +722,12 @@ class SegformerForSemanticSegmentation(SegformerPreTrainedModel):
     @replace_return_docstrings(output_type=SemanticSegmentationModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        pixel_values,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: torch.FloatTensor,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, SemanticSegmentationModelOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, height, width)`, *optional*):
             Ground truth semantic segmentation maps for computing the loss. Indices should be in `[0, ...,