Unverified Commit 7f3d4440 authored by João Gustavo A. Amorim's avatar João Gustavo A. Amorim Committed by GitHub
Browse files

add type annotations for ImageGPT (#16088)

parent 5b4c97d0
......@@ -17,7 +17,7 @@
import math
import os
import warnings
from typing import Tuple
from typing import Any, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
......@@ -167,12 +167,12 @@ def load_tf_weights_in_imagegpt(model, config, imagegpt_checkpoint_path):
class ImageGPTLayerNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-5):
def __init__(self, hidden_size: Tuple[int], eps: float = 1e-5):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.Tensor(hidden_size))
def forward(self, tensor):
def forward(self, tensor: torch.Tensor) -> tuple:
# input is not mean centered
return (
tensor
......@@ -182,7 +182,7 @@ class ImageGPTLayerNorm(nn.Module):
class ImageGPTAttention(nn.Module):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
def __init__(self, config, is_cross_attention: Optional[bool] = False, layer_idx: Optional[int] = None):
super().__init__()
max_positions = config.max_position_embeddings
......@@ -343,15 +343,15 @@ class ImageGPTAttention(nn.Module):
def forward(
self,
hidden_states,
layer_past=None,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=False,
output_attentions=False,
):
hidden_states: torch.Tensor,
layer_past: Optional[bool] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> tuple:
if encoder_hidden_states is not None:
if not hasattr(self, "q_attn"):
raise ValueError(
......@@ -404,7 +404,7 @@ class ImageGPTMLP(nn.Module):
self.act = ACT2FN[config.activation_function]
self.dropout = nn.Dropout(config.resid_pdrop)
def forward(self, hidden_states):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states = self.c_proj(hidden_states)
......@@ -430,15 +430,15 @@ class ImageGPTBlock(nn.Module):
def forward(
self,
hidden_states,
layer_past=None,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=False,
output_attentions=False,
):
hidden_states: torch.Tensor,
layer_past: Optional[bool] = None,
attention_mask: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
) -> tuple:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
......@@ -620,7 +620,7 @@ IMAGEGPT_INPUTS_DOCSTRING = r"""
class ImageGPTModel(ImageGPTPreTrainedModel):
_keys_to_ignore_on_load_missing = ["attn.masked_bias"]
def __init__(self, config):
def __init__(self, config: ImageGPTConfig):
super().__init__(config)
self.embed_dim = config.hidden_size
......@@ -656,21 +656,21 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
@replace_return_docstrings(output_type=BaseModelOutputWithPastAndCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Any,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
......@@ -900,7 +900,7 @@ class ImageGPTModel(ImageGPTPreTrainedModel):
class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
def __init__(self, config):
def __init__(self, config: ImageGPTConfig):
super().__init__(config)
self.transformer = ImageGPTModel(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size - 1, bias=False)
......@@ -917,7 +917,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past: Optional[bool] = None, **kwargs):
token_type_ids = kwargs.get("token_type_ids", None)
# only last token for inputs_ids if past is defined in kwargs
if past:
......@@ -949,22 +949,22 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Any,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
......@@ -1088,7 +1088,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel):
class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
def __init__(self, config):
def __init__(self, config: ImageGPTConfig):
super().__init__(config)
self.num_labels = config.num_labels
self.transformer = ImageGPTModel(config)
......@@ -1101,20 +1101,20 @@ class ImageGPTForImageClassification(ImageGPTPreTrainedModel):
@replace_return_docstrings(output_type=SequenceClassifierOutputWithPast, config_class=_CONFIG_FOR_DOC)
def forward(
self,
input_ids=None,
past_key_values=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
**kwargs,
):
input_ids: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**kwargs: Any,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment