Unverified Commit 94be4243 authored by Yi Heng Lim, committed by GitHub

Added type hints for PyTorch T5 model (#16257)

* Added type hints for PyTorch T5 model

* removed a type hint

* ran make style
parent 250b478a
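The pattern applied throughout is the one used across transformers: each optional tensor argument becomes Optional[<tensor type>] = None, and forward gains a Union return type covering both the plain-tuple fallback (return_dict=False) and the model-output dataclass. As a hedged, self-contained sketch of that convention on a toy module (ToyOutput and ToyEncoder are illustrative names, not part of this commit):

from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch import nn


@dataclass
class ToyOutput:
    # Stand-in for a model-output class such as Seq2SeqModelOutput (illustration only).
    last_hidden_state: torch.FloatTensor


class ToyEncoder(nn.Module):
    # Annotated in the same style the commit applies to T5: optional tensors
    # default to None, and the return annotation covers both output forms.
    def __init__(self, d_model: int = 8):
        super().__init__()
        self.proj = nn.Linear(d_model, d_model, bias=False)

    def forward(
        self,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.FloatTensor], ToyOutput]:
        if inputs_embeds is None:
            raise ValueError("inputs_embeds must be provided")
        hidden = self.proj(inputs_embeds)
        return ToyOutput(hidden) if return_dict else (hidden,)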
@@ -19,6 +19,7 @@ import copy
 import math
 import os
 import warnings
+from typing import Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -275,7 +276,7 @@ except Exception:
 
 
 class T5DenseReluDense(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: T5Config):
         super().__init__()
         self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
         self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
@@ -290,7 +291,7 @@ class T5DenseReluDense(nn.Module):
 
 
 class T5DenseGatedGeluDense(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: T5Config):
         super().__init__()
         self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
         self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
@@ -308,7 +309,7 @@ class T5DenseGatedGeluDense(nn.Module):
 
 
 class T5LayerFF(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: T5Config):
         super().__init__()
         if config.feed_forward_proj == "relu":
             self.DenseReluDense = T5DenseReluDense(config)
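Because __init__ now declares config: T5Config, a static checker can flag constructing these feed-forward blocks with the wrong config class. A minimal usage sketch (these modules are internal building blocks, so constructing them directly is unusual outside tests; the hyperparameters below are chosen arbitrarily for illustration):

import torch
from transformers import T5Config
from transformers.models.t5.modeling_t5 import T5DenseReluDense

config = T5Config(d_model=64, d_ff=256)  # matches the new `config: T5Config` hint
ffn = T5DenseReluDense(config)

hidden = torch.randn(1, 4, config.d_model)
print(ffn(hidden).shape)  # torch.Size([1, 4, 64]): wo projects back to d_model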
@@ -1343,22 +1344,22 @@ class T5Model(T5PreTrainedModel):
     @replace_return_docstrings(output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        decoder_head_mask: Optional[torch.FloatTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        decoder_inputs_embeds: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
         r"""
         Returns:
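With these hints, a T5Model call type-checks end to end: tokenizer outputs are LongTensor input_ids, and the result is a Seq2SeqModelOutput when return_dict=True. A minimal usage sketch, assuming the t5-small checkpoint (and sentencepiece) are available:

from transformers import T5Model, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5Model.from_pretrained("t5-small")

enc = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="pt")
dec = tokenizer("Studies show that", return_tensors="pt")

# input_ids / decoder_input_ids are torch.LongTensor, matching the new annotations.
outputs = model(
    input_ids=enc.input_ids,
    decoder_input_ids=dec.input_ids,
    return_dict=True,
)
print(outputs.last_hidden_state.shape)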
@@ -1462,7 +1463,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
         r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
     ]
 
-    def __init__(self, config):
+    def __init__(self, config: T5Config):
         super().__init__(config)
         self.model_dim = config.d_model
@@ -1537,23 +1538,23 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.BoolTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        decoder_head_mask: Optional[torch.FloatTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
             Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
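The labels hint makes the training entry point explicit: pass a LongTensor of target ids and the model computes the LM loss itself. A minimal usage sketch, again assuming the t5-small checkpoint:

from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

inputs = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt")
labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids

# labels is Optional[torch.LongTensor]; providing it triggers the internal loss computation.
outputs = model(input_ids=inputs.input_ids, labels=labels)
print(outputs.loss, outputs.logits.shape)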
@@ -1808,14 +1809,14 @@ class T5EncoderModel(T5PreTrainedModel):
     @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
         r"""
         Returns:
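T5EncoderModel has the smallest surface, so the annotated signature reads almost like documentation: encoder-only inputs in, BaseModelOutput (or a tuple) out. A minimal usage sketch:

from transformers import T5EncoderModel, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5EncoderModel.from_pretrained("t5-small")

inputs = tokenizer("Studies have shown that owning a dog is good for you", return_tensors="pt")
outputs = model(input_ids=inputs.input_ids)  # BaseModelOutput, per the new return annotation
print(outputs.last_hidden_state.shape)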