Unverified commit 7f3d4440, authored by João Gustavo A. Amorim, committed by GitHub
Browse files

add type annotations for ImageGPT (#16088)

parent 5b4c97d0
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import math import math
import os import os
import warnings import warnings
from typing import Tuple from typing import Any, Optional, Tuple, Union
import torch import torch
import torch.utils.checkpoint import torch.utils.checkpoint
...@@ -167,12 +167,12 @@ def load_tf_weights_in_imagegpt(model, config, imagegpt_checkpoint_path): ...@@ -167,12 +167,12 @@ def load_tf_weights_in_imagegpt(model, config, imagegpt_checkpoint_path):
class ImageGPTLayerNorm(nn.Module): class ImageGPTLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-5): def __init__(self, hidden_size: Tuple[int], eps: float = 1e-5):
super().__init__() super().__init__()
self.eps = eps self.eps = eps
self.weight = nn.Parameter(torch.Tensor(hidden_size)) self.weight = nn.Parameter(torch.Tensor(hidden_size))
def forward(self, tensor): def forward(self, tensor: torch.Tensor) -> tuple:
# input is not mean centered # input is not mean centered
return ( return (
tensor tensor
...@@ -182,7 +182,7 @@ class ImageGPTLayerNorm(nn.Module): ...@@ -182,7 +182,7 @@ class ImageGPTLayerNorm(nn.Module):
class ImageGPTAttention(nn.Module): class ImageGPTAttention(nn.Module):
def __init__(self, config, is_cross_attention=False, layer_idx=None): def __init__(self, config, is_cross_attention: Optional[bool] = False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
max_positions = config.max_position_embeddings max_positions = config.max_position_embeddings
...@@ -343,15 +343,15 @@ class ImageGPTAttention(nn.Module): ...@@ -343,15 +343,15 @@ class ImageGPTAttention(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
layer_past=None, layer_past: Optional[bool] = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache=False, use_cache: Optional[bool] = False,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> tuple:
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"): if not hasattr(self, "q_attn"):
raise ValueError( raise ValueError(
...@@ -404,7 +404,7 @@ class ImageGPTMLP(nn.Module): ...@@ -404,7 +404,7 @@ class ImageGPTMLP(nn.Module):
self.act = ACT2FN[config.activation_function] self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_pdrop) self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, hidden_states): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.c_fc(hidden_states) hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states) hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states) hidden_states = self.c_proj(hidden_states)
...@@ -430,15 +430,15 @@ class ImageGPTBlock(nn.Module): ...@@ -430,15 +430,15 @@ class ImageGPTBlock(nn.Module):
def forward( def forward(
self, self,
hidden_states, hidden_states: torch.Tensor,
layer_past=None, layer_past: Optional[bool] = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache=False, use_cache: Optional[bool] = False,
output_attentions=False, output_attentions: Optional[bool] = False,
): ) -> tuple:
residual = hidden_states residual = hidden_states
hidden_states = self.ln_1(hidden_states) hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn( attn_outputs = self.attn(
...@@ -620,7 +620,7 @@ IMAGEGPT_INPUTS_DOCSTRING = r""" ...@@ -620,7 +620,7 @@ IMAGEGPT_INPUTS_DOCSTRING = r"""
class ImageGPTModel(ImageGPTPreTrainedModel): class ImageGPTModel(ImageGPTPreTrainedModel):
_keys_to_ignore_on_load_missing = ["attn.masked_bias"] _keys_to_ignore_on_load_missing = ["attn.masked_bias"]
def __init__(self, config): def __init__(self, config: ImageGPTConfig):
super().__init__(config) super().__init__(config)
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
...@@ -656,21 +656,21 @@ class ImageGPTModel(ImageGPTPreTrainedModel): ...@@ -656,21 +656,21 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
@replace_return_docstrings(output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.Tensor] = None,
position_ids=None, position_ids: Optional[torch.Tensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
**kwargs, **kwargs: Any,
): ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
...@@ -900,7 +900,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel): ...@@ -900,7 +900,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel): class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"] _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
def __init__(self, config): def __init__(self, config: ImageGPTConfig):
super().__init__(config) super().__init__(config)
self.transformer = ImageGPTModel(config) self.transformer = ImageGPTModel(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size - 1, bias=False) self.lm_head = nn.Linear(config.n_embd, config.vocab_size - 1, bias=False)
...@@ -917,7 +917,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel): ...@@ -917,7 +917,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
def set_output_embeddings(self, new_embeddings): def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs): def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past: Optional[bool] = None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None) token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs # only last token for inputs_ids if past is defined in kwargs
if past: if past:
...@@ -949,22 +949,22 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel): ...@@ -949,22 +949,22 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.Tensor] = None,
position_ids=None, position_ids: Optional[torch.Tensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states=None, encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask=None, encoder_attention_mask: Optional[torch.Tensor] = None,
labels=None, labels: Optional[torch.Tensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
**kwargs, **kwargs: Any,
): ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
...@@ -1088,7 +1088,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel): ...@@ -1088,7 +1088,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
class ImageGPTForImageClassification(ImageGPTPreTrainedModel): class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"] _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
def __init__(self, config): def __init__(self, config: ImageGPTConfig):
super().__init__(config) super().__init__(config)
self.num_labels = config.num_labels self.num_labels = config.num_labels
self.transformer = ImageGPTModel(config) self.transformer = ImageGPTModel(config)
...@@ -1101,20 +1101,20 @@ class ImageGPTForImageClassification(ImageGPTPreTrainedModel): ...@@ -1101,20 +1101,20 @@ class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
@replace_return_docstrings(output_type=SequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC) @replace_return_docstrings(output_type=SequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.Tensor] = None,
position_ids=None, position_ids: Optional[torch.Tensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.Tensor] = None,
labels=None, labels: Optional[torch.Tensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
**kwargs, **kwargs: Any,
): ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment