Unverified Commit c4a0fb51 authored by Patrick von Platen, committed by GitHub

[WavLM] Correct position bias computation (#14805)

parent d194d639
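
This commit fixes WavLM's relative position bias: the bucketed position embedding (`rel_attn_embed`) is meant to live only in the first encoder layer, which computes the bias once and hands it on to the remaining layers; previously every layer instantiated its own table. WavLM buckets relative distances with the same bidirectional scheme popularized by T5. Below is a minimal sketch of that mechanism; the helper `relative_position_bucket` and the driver code are illustrative, not the transformers source, while `num_buckets=320` and `max_distance=800` match the defaults in the diff.

import math

import torch
import torch.nn as nn


def relative_position_bucket(relative_position, num_buckets=320, max_distance=800):
    # Bidirectional bucketing: half the buckets for positive offsets,
    # half for negative ones.
    num_buckets = num_buckets // 2
    buckets = (relative_position > 0).long() * num_buckets
    relative_position = torch.abs(relative_position)

    # Half of each side maps small offsets to their own exact bucket ...
    max_exact = num_buckets // 2
    is_small = relative_position < max_exact

    # ... the other half covers offsets up to max_distance logarithmically.
    relative_position_if_large = max_exact + (
        torch.log(relative_position.float().clamp(min=1) / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    relative_position_if_large = torch.minimum(
        relative_position_if_large,
        torch.full_like(relative_position_if_large, num_buckets - 1),
    )
    return buckets + torch.where(is_small, relative_position, relative_position_if_large)


num_heads, seq_len = 12, 50
rel_attn_embed = nn.Embedding(320, num_heads)  # the table this commit makes layer-0-only

positions = torch.arange(seq_len)
relative_position = positions[None, :] - positions[:, None]  # (seq_len, seq_len)
buckets = relative_position_bucket(relative_position)        # (seq_len, seq_len)
position_bias = rel_attn_embed(buckets).permute(2, 0, 1)     # (num_heads, seq_len, seq_len)
# `position_bias` is then added to the attention logits of every head.

Nearby offsets get exact buckets while distant ones share logarithmically wider buckets, so a 320-bucket table covers distances up to 800 frames with fine resolution where it matters most.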
@@ -394,6 +394,7 @@ class WavLMAttention(nn.Module):
         dropout: float = 0.0,
         num_buckets: int = 320,
         max_distance: int = 800,
+        has_relative_position_bias: bool = True,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -418,6 +419,8 @@ class WavLMAttention(nn.Module):
         self.gru_rel_pos_const = nn.Parameter(torch.ones(1, self.num_heads, 1, 1))
         self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)
-        self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads)
+        if has_relative_position_bias:
+            self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads)
     def forward(
@@ -573,7 +576,7 @@ class WavLMFeedForward(nn.Module):
 class WavLMEncoderLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
         super().__init__()
         self.attention = WavLMAttention(
             embed_dim=config.hidden_size,
@@ -581,6 +584,7 @@ class WavLMEncoderLayer(nn.Module):
             dropout=config.attention_dropout,
             num_buckets=config.num_buckets,
             max_distance=config.max_bucket_distance,
+            has_relative_position_bias=has_relative_position_bias,
         )
         self.dropout = nn.Dropout(config.hidden_dropout)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -613,7 +617,7 @@ class WavLMEncoderLayer(nn.Module):
 class WavLMEncoderLayerStableLayerNorm(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
         super().__init__()
         self.attention = WavLMAttention(
             embed_dim=config.hidden_size,
@@ -621,6 +625,7 @@ class WavLMEncoderLayerStableLayerNorm(nn.Module):
             dropout=config.attention_dropout,
             num_buckets=config.num_buckets,
             max_distance=config.max_bucket_distance,
+            has_relative_position_bias=has_relative_position_bias,
         )
         self.dropout = nn.Dropout(config.hidden_dropout)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -655,7 +660,9 @@ class WavLMEncoder(nn.Module):
         self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout)
-        self.layers = nn.ModuleList([WavLMEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layers = nn.ModuleList(
+            [WavLMEncoderLayer(config, has_relative_position_bias=(i == 0)) for i in range(config.num_hidden_layers)]
+        )
         self.gradient_checkpointing = False
     def forward(
@@ -743,7 +750,10 @@ class WavLMEncoderStableLayerNorm(nn.Module):
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout)
         self.layers = nn.ModuleList(
-            [WavLMEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
+            [
+                WavLMEncoderLayerStableLayerNorm(config, has_relative_position_bias=(i == 0))
+                for i in range(config.num_hidden_layers)
+            ]
         )
         self.gradient_checkpointing = False
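
The `has_relative_position_bias=(i == 0)` arguments above are the heart of the fix: only layer 0 owns the embedding table, so each encoder threads the bias that layer 0 computes through the remaining layers. A minimal sketch of that loop, assuming each layer returns the bias alongside its hidden states (the exact transformers layer signatures differ in detail):

position_bias = None
for layer in self.layers:
    # Layer 0, the only layer with rel_attn_embed, computes the bias;
    # every later layer receives and reuses the same tensor.
    hidden_states, position_bias = layer(hidden_states, position_bias=position_bias)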