Unverified commit bd1a43b6 authored by Sanchit Gandhi, committed by GitHub

[S2T, Whisper] Add copied from statements (#20787)

* [S2T, Whisper] Add copied from statements

* rebase and fix-copies
parent 5eecf3ff
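
For readers unfamiliar with the convention: a "# Copied from ..." comment tells the Transformers repo-consistency tooling that the class below it is a verbatim copy of the referenced class, with the renaming given after `with` applied (here `MBart->Speech2Text` and `MBart->Whisper`), so that `make fix-copies` can regenerate the copy whenever the source changes. A minimal sketch of the pattern follows, using hypothetical stand-in classes (not the real encoder layers from this diff):

import torch
import torch.nn as nn


class MBartToyBlock(nn.Module):
    """Stand-in for the source class (e.g. MBartEncoderLayer)."""

    def __init__(self, embed_dim: int = 16):
        super().__init__()
        self.fc = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc(hidden_states)


# Copied from __main__.MBartToyBlock with MBart->Speech2Text
class Speech2TextToyBlock(nn.Module):
    """Byte-for-byte copy of the source once the MBart->Speech2Text renaming is applied."""

    def __init__(self, embed_dim: int = 16):
        super().__init__()
        self.fc = nn.Linear(embed_dim, embed_dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc(hidden_states)
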
src/transformers/models/speech_to_text/modeling_speech_to_text.py

@@ -354,6 +354,7 @@ class Speech2TextAttention(nn.Module):
         return attn_output, attn_weights_reshaped, past_key_value
 
 
+# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text
 class Speech2TextEncoderLayer(nn.Module):
     def __init__(self, config: Speech2TextConfig):
         super().__init__()
@@ -377,14 +378,14 @@ class Speech2TextEncoderLayer(nn.Module):
         attention_mask: torch.Tensor,
         layer_head_mask: torch.Tensor,
         output_attentions: bool = False,
-    ):
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
             layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
-                `(config.encoder_attention_heads,)`.
+                `(encoder_attention_heads,)`.
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
@@ -422,6 +423,7 @@ class Speech2TextEncoderLayer(nn.Module):
         return outputs
 
 
+# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text
 class Speech2TextDecoderLayer(nn.Module):
     def __init__(self, config: Speech2TextConfig):
         super().__init__()
@@ -460,7 +462,7 @@ class Speech2TextDecoderLayer(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = True,
-    ):
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -473,7 +475,7 @@ class Speech2TextDecoderLayer(nn.Module):
             layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                 `(encoder_attention_heads,)`.
             cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
-                size *(decoder_attention_heads,)*.
+                size `(decoder_attention_heads,)`.
             past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
...
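
The small docstring edits in the hunks above (dropping the `config.` prefix from `(config.encoder_attention_heads,)` and switching *(decoder_attention_heads,)* to `(decoder_attention_heads,)`) exist so that each copied block matches the MBart source byte-for-byte after the name substitution; otherwise the copy check would flag them as stale. Below is a conceptual sketch of that check, assuming nothing about the real utils/check_copies.py beyond the substitution rule visible in the comments themselves:

def apply_renaming(source_code: str, rule: str) -> str:
    """Apply an 'Old->New' rule taken from a '# Copied from ... with Old->New' comment."""
    old, new = rule.split("->")
    return source_code.replace(old.strip(), new.strip())


def is_faithful_copy(source_code: str, copied_code: str, rule: str) -> bool:
    """True if the copy equals the source once the renaming has been applied."""
    return apply_renaming(source_code, rule) == copied_code


source = "layer_head_mask: mask of size `(encoder_attention_heads,)` for MBartEncoderLayer"
copy = "layer_head_mask: mask of size `(encoder_attention_heads,)` for Speech2TextEncoderLayer"
assert is_faithful_copy(source, copy, "MBart->Speech2Text")
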
src/transformers/models/whisper/modeling_whisper.py

@@ -261,7 +261,7 @@ class WhisperAttention(nn.Module):
         return attn_output, attn_weights_reshaped, past_key_value
 
 
-# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextEncoderLayer with Speech2Text->Whisper
+# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper
 class WhisperEncoderLayer(nn.Module):
     def __init__(self, config: WhisperConfig):
         super().__init__()
@@ -285,14 +285,14 @@ class WhisperEncoderLayer(nn.Module):
         attention_mask: torch.Tensor,
         layer_head_mask: torch.Tensor,
         output_attentions: bool = False,
-    ):
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
             layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
-                `(config.encoder_attention_heads,)`.
+                `(encoder_attention_heads,)`.
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
@@ -330,7 +330,7 @@ class WhisperEncoderLayer(nn.Module):
         return outputs
 
 
-# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextDecoderLayer with Speech2Text->Whisper
+# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper
 class WhisperDecoderLayer(nn.Module):
     def __init__(self, config: WhisperConfig):
         super().__init__()
@@ -369,7 +369,7 @@ class WhisperDecoderLayer(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = True,
-    ):
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -382,7 +382,7 @@ class WhisperDecoderLayer(nn.Module):
             layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                 `(encoder_attention_heads,)`.
             cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
-                size *(decoder_attention_heads,)*.
+                size `(decoder_attention_heads,)`.
             past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
...
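
The follow-up commit message ("rebase and fix-copies") refers to the repo's consistency tooling rather than a manual edit. A hedged sketch of how these annotations are typically re-validated locally, assuming a standard Transformers checkout and its usual Makefile targets:

import subprocess

# Regenerate any stale "# Copied from" blocks (the step the commit message calls "fix-copies").
subprocess.run(["make", "fix-copies"], check=True)

# Verify consistency without rewriting anything.
subprocess.run(["make", "repo-consistency"], check=True)
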