"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "870ff9e1dab249e4ffd8363ce132aa5145c94604"
Unverified commit f86235ad authored by Jacob Dineen, committed by GitHub

Add type annotations for CLIP (torch) (#16059) (#16106)

* clip type hinting #16059

* removed optional type annotations for dataclass in CLIPOutput

* type annotation fixes per Rocket - Clip Torch
parent c1000e70
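
The pattern applied throughout the diff below is uniform: keyword arguments that default to None gain Optional[...] annotations, the output_attentions / output_hidden_states / return_dict flags become Optional[bool], and forward methods that may return either a plain tuple or a ModelOutput subclass get a Union[...] return type, which is why Union is added to the typing import. A minimal sketch of that pattern, using a hypothetical standalone function rather than any class touched by this commit:

from typing import Optional, Tuple, Union

import torch

from transformers.modeling_outputs import BaseModelOutputWithPooling


def annotated_forward_sketch(  # hypothetical example, not part of the diff
    input_ids: Optional[torch.LongTensor] = None,  # None-able inputs become Optional[...]
    attention_mask: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:  # tuple or ModelOutput, depending on return_dict
    ...  # signature sketch only; no body
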
@@ -16,7 +16,7 @@
 from dataclasses import dataclass
-from typing import Any, Optional, Tuple
+from typing import Any, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
@@ -128,7 +128,7 @@ class CLIPVisionEmbeddings(nn.Module):
         self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
         self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))

-    def forward(self, pixel_values):
+    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
         batch_size = pixel_values.shape[0]
         patch_embeds = self.patch_embedding(pixel_values)  # shape = [*, width, grid, grid]
         patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
@@ -150,7 +150,12 @@ class CLIPTextEmbeddings(nn.Module):
         # position_ids (1, len position emb) is contiguous in memory and exported when serialized
         self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

-    def forward(self, input_ids=None, position_ids=None, inputs_embeds=None):
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
         seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]

         if position_ids is None:
@@ -193,7 +198,7 @@ class CLIPAttention(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
         causal_attention_mask: Optional[torch.Tensor] = None,
-        output_attentions: bool = False,
+        output_attentions: Optional[bool] = False,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
         """Input shape: Batch x Time x Channel"""
@@ -272,7 +277,7 @@ class CLIPMLP(nn.Module):
         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)

-    def forward(self, hidden_states):
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.fc1(hidden_states)
         hidden_states = self.activation_fn(hidden_states)
         hidden_states = self.fc2(hidden_states)
@@ -293,8 +298,8 @@ class CLIPEncoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         attention_mask: torch.Tensor,
         causal_attention_mask: torch.Tensor,
-        output_attentions: bool = False,
-    ):
+        output_attentions: Optional[bool] = False,
+    ) -> Tuple[torch.FloatTensor]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -502,12 +507,12 @@ class CLIPEncoder(nn.Module):
     def forward(
         self,
         inputs_embeds,
-        attention_mask=None,
-        causal_attention_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        attention_mask: Optional[torch.Tensor] = None,
+        causal_attention_mask: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutput]:
         r"""
         Args:
             inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
@@ -600,13 +605,13 @@ class CLIPTextTransformer(nn.Module):
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        position_ids=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         Returns:
@@ -689,13 +694,13 @@ class CLIPTextModel(CLIPPreTrainedModel):
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        position_ids=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         Returns:
@@ -738,11 +743,11 @@ class CLIPVisionTransformer(nn.Module):
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
     def forward(
         self,
-        pixel_values=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         Returns:
@@ -798,11 +803,11 @@ class CLIPVisionModel(CLIPPreTrainedModel):
     @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPVisionConfig)
     def forward(
         self,
-        pixel_values=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseModelOutputWithPooling]:
         r"""
         Returns:
@@ -870,13 +875,13 @@ class CLIPModel(CLIPPreTrainedModel):
     @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
     def get_text_features(
         self,
-        input_ids=None,
-        attention_mask=None,
-        position_ids=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> torch.FloatTensor:
         r"""
         Returns:
             text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
@@ -910,11 +915,11 @@ class CLIPModel(CLIPPreTrainedModel):
     @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
     def get_image_features(
         self,
-        pixel_values=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        pixel_values: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> torch.FloatTensor:
         r"""
         Returns:
             image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by
@@ -953,15 +958,15 @@ class CLIPModel(CLIPPreTrainedModel):
     @replace_return_docstrings(output_type=CLIPOutput, config_class=CLIPConfig)
     def forward(
         self,
-        input_ids=None,
-        pixel_values=None,
-        attention_mask=None,
-        position_ids=None,
-        return_loss=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        pixel_values: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        return_loss: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, CLIPOutput]:
         r"""
         Returns:
...
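
For context, a short usage sketch (not part of this commit) showing how the annotated entry points read at a call site; it assumes the public openai/clip-vit-base-patch32 checkpoint and the COCO sample image used in the transformers documentation:

import requests
import torch
from PIL import Image

from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
inputs = processor(text=["a photo of a cat", "a photo of a dog"], images=image, return_tensors="pt", padding=True)

with torch.no_grad():
    outputs = model(**inputs, return_dict=True)  # annotated as Union[Tuple, CLIPOutput]
    text_features = model.get_text_features(     # annotated as torch.FloatTensor
        input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"]
    )

probs = outputs.logits_per_image.softmax(dim=1)  # image-text similarity scores as probabilities

With return_dict=True the call returns a CLIPOutput at runtime, matching the Union[Tuple, CLIPOutput] annotation, so attribute access such as outputs.logits_per_image is covered by the declared types.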