Unverified Commit 6c2f3ed7 authored by Bhavika Tekwani, committed by GitHub

Add type hints for Luke in PyTorch (#16111)



* Add type hints for LukeModel

* Add type hints for entitypairclassification

* Remove blank space

Co-authored-by: bhavika <bhavika@debian-BULLSEYE-live-builder-AMD64>
parent 37a9fc49
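
The return annotations introduced in this diff (`Union[Tuple, ...]`) describe behavior the models already had: with `return_dict=False` the forward pass returns a plain tuple, otherwise a model-output dataclass. A minimal sketch of that for `LukeModel`, assuming the public `studio-ousia/luke-base` checkpoint (the checkpoint choice and example sentence are illustrative, not part of this commit):

```python
from transformers import LukeModel, LukeTokenizer

tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base")
model = LukeModel.from_pretrained("studio-ousia/luke-base")

inputs = tokenizer("Beyoncé lives in Los Angeles.", return_tensors="pt")

# return_dict=True -> BaseLukeModelOutputWithPooling, the dataclass branch of the annotation
outputs = model(**inputs, return_dict=True)
print(type(outputs).__name__, outputs.last_hidden_state.shape)

# return_dict=False -> a plain tuple, the Tuple branch of the annotation
outputs_tuple = model(**inputs, return_dict=False)
print(type(outputs_tuple).__name__, outputs_tuple[0].shape)
```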
@@ -16,7 +16,7 @@
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
@@ -880,7 +880,7 @@ class LukeModel(LukePreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"position_ids"]

-    def __init__(self, config, add_pooling_layer=True):
+    def __init__(self, config: LukeConfig, add_pooling_layer: bool = True):
         super().__init__(config)
         self.config = config
@@ -912,20 +912,20 @@ class LukeModel(LukePreTrainedModel):
     @replace_return_docstrings(output_type=BaseLukeModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        entity_ids=None,
-        entity_attention_mask=None,
-        entity_token_type_ids=None,
-        entity_position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        entity_ids: Optional[torch.LongTensor] = None,
+        entity_attention_mask: Optional[torch.FloatTensor] = None,
+        entity_token_type_ids: Optional[torch.LongTensor] = None,
+        entity_position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, BaseLukeModelOutputWithPooling]:
         r"""
         Returns:
@@ -1169,22 +1169,22 @@ class LukeForMaskedLM(LukePreTrainedModel):
     @replace_return_docstrings(output_type=LukeMaskedLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        entity_ids=None,
-        entity_attention_mask=None,
-        entity_token_type_ids=None,
-        entity_position_ids=None,
-        labels=None,
-        entity_labels=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        entity_ids: Optional[torch.LongTensor] = None,
+        entity_attention_mask: Optional[torch.LongTensor] = None,
+        entity_token_type_ids: Optional[torch.LongTensor] = None,
+        entity_position_ids: Optional[torch.LongTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        entity_labels: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, LukeMaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
@@ -1282,21 +1282,21 @@ class LukeForEntityClassification(LukePreTrainedModel):
     @replace_return_docstrings(output_type=EntityClassificationOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        entity_ids=None,
-        entity_attention_mask=None,
-        entity_token_type_ids=None,
-        entity_position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        entity_ids: Optional[torch.LongTensor] = None,
+        entity_attention_mask: Optional[torch.FloatTensor] = None,
+        entity_token_type_ids: Optional[torch.LongTensor] = None,
+        entity_position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.FloatTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, EntityClassificationOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)` or `(batch_size, num_labels)`, *optional*):
             Labels for computing the classification loss. If the shape is `(batch_size,)`, the cross entropy loss is
@@ -1397,21 +1397,21 @@ class LukeForEntityPairClassification(LukePreTrainedModel):
     @replace_return_docstrings(output_type=EntityPairClassificationOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        entity_ids=None,
-        entity_attention_mask=None,
-        entity_token_type_ids=None,
-        entity_position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        entity_ids: Optional[torch.LongTensor] = None,
+        entity_attention_mask: Optional[torch.FloatTensor] = None,
+        entity_token_type_ids: Optional[torch.LongTensor] = None,
+        entity_position_ids: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, EntityPairClassificationOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)` or `(batch_size, num_labels)`, *optional*):
             Labels for computing the classification loss. If the shape is `(batch_size,)`, the cross entropy loss is
@@ -1517,23 +1517,23 @@ class LukeForEntitySpanClassification(LukePreTrainedModel):
     @replace_return_docstrings(output_type=EntitySpanClassificationOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
-        input_ids=None,
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        entity_ids=None,
-        entity_attention_mask=None,
-        entity_token_type_ids=None,
-        entity_position_ids=None,
-        entity_start_positions=None,
-        entity_end_positions=None,
-        head_mask=None,
-        inputs_embeds=None,
-        labels=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        entity_ids: Optional[torch.LongTensor] = None,
+        entity_attention_mask: Optional[torch.LongTensor] = None,
+        entity_token_type_ids: Optional[torch.LongTensor] = None,
+        entity_position_ids: Optional[torch.LongTensor] = None,
+        entity_start_positions: Optional[torch.LongTensor] = None,
+        entity_end_positions: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        labels: Optional[torch.LongTensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple, EntitySpanClassificationOutput]:
         r"""
         entity_start_positions (`torch.LongTensor`):
             The start positions of entities in the word token sequence.
...
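
Since one of the commit bullets covers the entity pair classification head, here is a hedged usage sketch of the now-annotated `LukeForEntityPairClassification.forward`; the `studio-ousia/luke-large-finetuned-tacred` checkpoint, the example sentence, and the character spans are illustrative assumptions, not part of this diff:

```python
from transformers import LukeForEntityPairClassification, LukeTokenizer

tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-large-finetuned-tacred")
model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-large-finetuned-tacred")

text = "Beyoncé lives in Los Angeles."
entity_spans = [(0, 7), (17, 28)]  # character spans of the head and tail entities

# The tokenizer produces input_ids, entity_ids, entity_position_ids, etc.,
# i.e. the Optional[torch.LongTensor] arguments annotated in this diff.
inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")

outputs = model(**inputs)  # EntityPairClassificationOutput when return_dict is left at its default
predicted_class = int(outputs.logits[0].argmax())
print("Predicted relation:", model.config.id2label[predicted_class])
```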