Unverified Commit e70abdad authored by Fx039482's avatar Fx039482 Committed by GitHub
Browse files

Update modeling_cvt.py (#17846)

As shown in the colab notebook I added the missing type hints for " CvtForImageClassification
CvtModel
"
parent 1a7ef334
......@@ -17,7 +17,7 @@
import collections.abc
from dataclasses import dataclass
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import torch
import torch.utils.checkpoint
......@@ -604,7 +604,13 @@ class CvtModel(CvtPreTrainedModel):
modality="vision",
expected_output=_EXPECTED_OUTPUT_SHAPE,
)
def forward(self, pixel_values=None, output_hidden_states=None, return_dict=None):
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithCLSToken]:
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
......@@ -662,11 +668,11 @@ class CvtForImageClassification(CvtPreTrainedModel):
)
def forward(
self,
pixel_values=None,
labels=None,
output_hidden_states=None,
return_dict=None,
):
pixel_values: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, ImageClassifierOutputWithNoAttention]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment