Unverified Commit e70abdad authored by Fx039482's avatar Fx039482 Committed by GitHub
Browse files

Update modeling_cvt.py (#17846)

As shown in the colab notebook I added the missing type hints for " CvtForImageClassification
CvtModel
"
parent 1a7ef334
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import collections.abc import collections.abc
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Tuple from typing import Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -604,7 +604,13 @@ class CvtModel(CvtPreTrainedModel): ...@@ -604,7 +604,13 @@ class CvtModel(CvtPreTrainedModel):
modality="vision", modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE, expected_output=_EXPECTED_OUTPUT_SHAPE,
) )
def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None): def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithCLSToken]:
output_hidden_states = ( output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
) )
...@@ -662,11 +668,11 @@ class CvtForImageClassification(CvtPreTrainedModel): ...@@ -662,11 +668,11 @@ class CvtForImageClassification(CvtPreTrainedModel):
) )
def forward( def forward(
self, self,
pixel_values=None, pixel_values: Optional[torch.Tensor] = None,
labels=None, labels: Optional[torch.Tensor] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ..., Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment