Unverified Commit 5493c10e authored by Hyeonsoo Lee's avatar Hyeonsoo Lee Committed by GitHub
Browse files

Add type hints for PoolFormer in Pytorch (#16121)

parent 6c2f3ed7
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import collections.abc import collections.abc
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -379,7 +379,12 @@ class PoolFormerModel(PoolFormerPreTrainedModel): ...@@ -379,7 +379,12 @@ class PoolFormerModel(PoolFormerPreTrainedModel):
modality="vision", modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE, expected_output=_EXPECTED_OUTPUT_SHAPE,
) )
def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None): def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, PoolFormerModelOutput]:
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) )
...@@ -446,11 +451,11 @@ class PoolFormerForImageClassification(PoolFormerPreTrainedModel): ...@@ -446,11 +451,11 @@ class PoolFormerForImageClassification(PoolFormerPreTrainedModel):
) )
def forward( def forward(
self, self,
pixel_values=None, pixel_values: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, PoolFormerClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ..., Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment