"tests/models/vscode:/vscode.git/clone" did not exist on "648bd5a8aad926d611d900d3d202dce37e085359"
Unverified Commit 77ea35b9 authored by Partho's avatar Partho Committed by GitHub
Browse files

added type hints (#19015)

parent fc21c9be
...@@ -468,12 +468,12 @@ class FSMTEncoder(nn.Module): ...@@ -468,12 +468,12 @@ class FSMTEncoder(nn.Module):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
attention_mask=None, attention_mask: Optional[torch.Tensor] = None,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
output_attentions=False, output_attentions: bool = False,
output_hidden_states=False, output_hidden_states: bool = False,
return_dict=True, return_dict: bool = True,
): ):
""" """
Args: Args:
...@@ -669,18 +669,18 @@ class FSMTDecoder(nn.Module): ...@@ -669,18 +669,18 @@ class FSMTDecoder(nn.Module):
def forward( def forward(
self, self,
input_ids, input_ids: torch.Tensor,
encoder_hidden_states, encoder_hidden_states: torch.Tensor,
encoder_padding_mask, encoder_padding_mask: torch.Tensor,
decoder_padding_mask, decoder_padding_mask: torch.Tensor,
decoder_causal_mask, decoder_causal_mask: torch.Tensor,
head_mask=None, head_mask: Optional[torch.Tensor] = None,
cross_attn_head_mask=None, cross_attn_head_mask: Optional[torch.Tensor] = None,
past_key_values=None, past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache=False, use_cache: bool = False,
output_attentions=False, output_attentions: bool = False,
output_hidden_states=False, output_hidden_states: bool = False,
return_dict=True, return_dict: bool = True,
): ):
""" """
Includes several features from "Jointly Learning to Align and Translate with Transformer Models" (Garg et al., Includes several features from "Jointly Learning to Align and Translate with Transformer Models" (Garg et al.,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment