Unverified Commit 5493c10e authored by Hyeonsoo Lee's avatar Hyeonsoo Lee Committed by GitHub
Browse files

Add type hints for PoolFormer in Pytorch (#16121)

parent 6c2f3ed7
......@@ -17,7 +17,7 @@
import collections.abc
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
......@@ -379,7 +379,12 @@ class PoolFormerModel(PoolFormerPreTrainedModel):
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None):
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, PoolFormerModelOutput]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
......@@ -446,11 +451,11 @@ class PoolFormerForImageClassification(PoolFormerPreTrainedModel):
)
def forward(
self,
pixel_values=None,
labels=None,
output_hidden_states=None,
return_dict=None,
):
pixel_values: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, PoolFormerClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment