Unverified Commit bd1a43b6 authored by Sanchit Gandhi, committed by GitHub

[S2T, Whisper] Add copied from statements (#20787)

* [S2T, Whisper] Add copied from statements

* rebase and fix-copies
parent 5eecf3ff
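In transformers, a `# Copied from ...` comment above a class declares it to be a verbatim copy of the referenced class with the listed renames applied (here `MBart->Speech2Text` and `MBart->Whisper`). The `make fix-copies` check regenerates each copy from its source of truth, which is why the docstrings and return annotations in the diff below are adjusted to match MBart's exactly. As a rough illustration only (this is not the actual utils/check_copies.py logic, and the helper name is made up), the idea is:

import inspect


def is_consistent_copy(copy_obj, reference_obj, replacements):
    """Return True if copy_obj's source equals reference_obj's source
    after applying the declared name replacements, e.g. {"MBart": "Whisper"}."""
    expected = inspect.getsource(reference_obj)
    for old, new in replacements.items():
        # Apply the rename declared in the "# Copied from ... with old->new" comment.
        expected = expected.replace(old, new)
    return inspect.getsource(copy_obj) == expected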
@@ -354,6 +354,7 @@ class Speech2TextAttention(nn.Module):
         return attn_output, attn_weights_reshaped, past_key_value
 
 
+# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text
 class Speech2TextEncoderLayer(nn.Module):
     def __init__(self, config: Speech2TextConfig):
         super().__init__()
@@ -377,14 +378,14 @@ class Speech2TextEncoderLayer(nn.Module):
         attention_mask: torch.Tensor,
         layer_head_mask: torch.Tensor,
         output_attentions: bool = False,
-    ):
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
             layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
-                `(config.encoder_attention_heads,)`.
+                `(encoder_attention_heads,)`.
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
@@ -422,6 +423,7 @@ class Speech2TextEncoderLayer(nn.Module):
         return outputs
 
 
+# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text
 class Speech2TextDecoderLayer(nn.Module):
     def __init__(self, config: Speech2TextConfig):
         super().__init__()
@@ -460,7 +462,7 @@ class Speech2TextDecoderLayer(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = True,
-    ):
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -473,7 +475,7 @@ class Speech2TextDecoderLayer(nn.Module):
             layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                 `(encoder_attention_heads,)`.
             cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
-                size *(decoder_attention_heads,)*.
+                size `(decoder_attention_heads,)`.
             past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
@@ -261,7 +261,7 @@ class WhisperAttention(nn.Module):
         return attn_output, attn_weights_reshaped, past_key_value
 
 
-# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextEncoderLayer with Speech2Text->Whisper
+# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper
 class WhisperEncoderLayer(nn.Module):
     def __init__(self, config: WhisperConfig):
         super().__init__()
@@ -285,14 +285,14 @@ class WhisperEncoderLayer(nn.Module):
         attention_mask: torch.Tensor,
         layer_head_mask: torch.Tensor,
         output_attentions: bool = False,
-    ):
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
             attention_mask (`torch.FloatTensor`): attention mask of size
                 `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
             layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
-                `(config.encoder_attention_heads,)`.
+                `(encoder_attention_heads,)`.
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                 returned tensors for more detail.
@@ -330,7 +330,7 @@ class WhisperEncoderLayer(nn.Module):
         return outputs
 
 
-# Copied from transformers.models.speech_to_text.modeling_speech_to_text.Speech2TextDecoderLayer with Speech2Text->Whisper
+# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Whisper
 class WhisperDecoderLayer(nn.Module):
     def __init__(self, config: WhisperConfig):
         super().__init__()
@@ -369,7 +369,7 @@ class WhisperDecoderLayer(nn.Module):
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = True,
-    ):
+    ) -> torch.Tensor:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(seq_len, batch, embed_dim)`
@@ -382,7 +382,7 @@ class WhisperDecoderLayer(nn.Module):
             layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                 `(encoder_attention_heads,)`.
             cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
-                size *(decoder_attention_heads,)*.
+                size `(decoder_attention_heads,)`.
             past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
             output_attentions (`bool`, *optional*):
                 Whether or not to return the attentions tensors of all attention layers. See `attentions` under
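With the updated comments in place, a check along the lines of the sketch above should report each of these layers as consistent with its MBart counterpart after the rename. For example, using the illustrative (made-up) helper from the sketch and assuming transformers is installed (the real checker also handles upper/lower-case variants of the pattern, which the sketch does not):

from transformers.models.mbart.modeling_mbart import MBartEncoderLayer
from transformers.models.whisper.modeling_whisper import WhisperEncoderLayer

# Expected to print True once WhisperEncoderLayer mirrors MBartEncoderLayer modulo the rename.
print(is_consistent_copy(WhisperEncoderLayer, MBartEncoderLayer, {"MBart": "Whisper"}))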