Unverified Commit b18dfd95 authored by Anmol Joshi's avatar Anmol Joshi Committed by GitHub
Browse files

added type hints to CTRL pytorch (#16593)

* Completed documentation of CTRL

* Missing optional None

* Added return types

* updated imports

* Update modeling_ctrl.py
parent 208f4c10
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
""" PyTorch CTRL model.""" """ PyTorch CTRL model."""
from typing import Tuple from typing import Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
...@@ -359,18 +359,18 @@ class CTRLModel(CTRLPreTrainedModel): ...@@ -359,18 +359,18 @@ class CTRLModel(CTRLPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
...@@ -521,19 +521,19 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): ...@@ -521,19 +521,19 @@ class CTRLLMHeadModel(CTRLPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: torch.LongTensor = None,
head_mask=None, head_mask: torch.FloatTensor = None,
inputs_embeds=None, inputs_embeds: torch.FloatTensor = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithPast]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
...@@ -625,19 +625,19 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel): ...@@ -625,19 +625,19 @@ class CTRLForSequenceClassification(CTRLPreTrainedModel):
) )
def forward( def forward(
self, self,
input_ids=None, input_ids: Optional[torch.LongTensor] = None,
past_key_values=None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
attention_mask=None, attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids=None, token_type_ids: Optional[torch.LongTensor] = None,
position_ids=None, position_ids: Optional[torch.LongTensor] = None,
head_mask=None, head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds=None, inputs_embeds: Optional[torch.FloatTensor] = None,
labels=None, labels: Optional[torch.LongTensor] = None,
use_cache=None, use_cache: Optional[bool] = None,
output_attentions=None, output_attentions: Optional[bool] = None,
output_hidden_states=None, output_hidden_states: Optional[bool] = None,
return_dict=None, return_dict: Optional[bool] = None,
): ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment