Unverified Commit 59a9c83e authored by Gunjan Chhablani, committed by GitHub

Fix Bart type hints (#16297)

* Add type hints to PLBart PyTorch

* Remove pending merge conflicts

* Fix PLBart Type Hints

* Add changes from review
parent afc5a1ea
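A quick illustration of what these hints buy downstream users, as a minimal sketch (not part of this commit): it assumes `torch` and a `transformers` version that ships PLBart, and uses a randomly initialized model so nothing is downloaded. With the annotated signature, a static checker such as mypy can validate call sites of `PLBartModel.forward` before runtime.

import torch
from transformers import PLBartConfig, PLBartModel

# Randomly initialized model; the default config is enough to exercise the
# typed signature (no pretrained weights needed).
model = PLBartModel(PLBartConfig())

input_ids = torch.ones((1, 4), dtype=torch.long)       # matches Optional[torch.LongTensor]
attention_mask = torch.ones((1, 4), dtype=torch.long)  # matches Optional[torch.LongTensor]

outputs = model(input_ids=input_ids, attention_mask=attention_mask)
print(type(outputs).__name__)  # Seq2SeqModelOutput when return_dict resolves to True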
@@ -16,7 +16,7 @@
 import copy
 import math
 import random
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.utils.checkpoint
@@ -1142,21 +1142,21 @@ class PLBartModel(PLBartPreTrainedModel):
     )
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.LongTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ):
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
@@ -1271,23 +1271,23 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
     @add_end_docstrings(PLBART_GENERATION_EXAMPLE)
     def forward(
         self,
-        input_ids=None,
-        attention_mask=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        encoder_outputs=None,
-        past_key_values=None,
-        inputs_embeds=None,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        decoder_input_ids: Optional[torch.LongTensor] = None,
+        decoder_attention_mask: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.LongTensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
         decoder_inputs_embeds=None,
-        labels=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
+        labels: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+    ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
@@ -1345,16 +1345,16 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
     def prepare_inputs_for_generation(
         self,
-        decoder_input_ids,
-        past=None,
-        attention_mask=None,
-        head_mask=None,
-        decoder_head_mask=None,
-        cross_attn_head_mask=None,
-        use_cache=None,
-        encoder_outputs=None,
+        decoder_input_ids: torch.LongTensor,
+        past: Optional[List[torch.FloatTensor]] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        decoder_head_mask: Optional[torch.Tensor] = None,
+        cross_attn_head_mask: Optional[torch.Tensor] = None,
+        use_cache: Optional[bool] = None,
+        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
         **kwargs  # TODO: Check if this is needed. It is unused?
-    ):
+    ) -> Dict[str, Any]:
         # cut decoder_input_ids if past is used
         if past is not None:
             decoder_input_ids = decoder_input_ids[:, -1:]
...
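The final hunk also shows why `prepare_inputs_for_generation` now advertises a `Dict[str, Any]` return while trimming its input: once cached key/value states (`past`) exist, only the newest token has to pass through the decoder. A self-contained sketch of just that slicing step, with hypothetical values standing in for real generation state:

import torch

decoder_input_ids = torch.tensor([[2, 31, 7, 52]])  # tokens decoded so far (hypothetical ids)
past = ["cached key/value states"]  # stand-in for the List[torch.FloatTensor] cache

# Cut decoder_input_ids if past is used: earlier positions are already cached,
# so only the final token is re-fed to the decoder.
if past is not None:
    decoder_input_ids = decoder_input_ids[:, -1:]

print(decoder_input_ids)  # tensor([[52]])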