"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "d8e3bdbb4cce939e8f95e0f1fa33bdd7350f4b79"
Unverified commit c4a0fb51, authored by Patrick von Platen, committed by GitHub

[WavLM] Correct position bias computation (#14805)
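This change makes the bucketed relative position embedding (`rel_attn_embed`) a first-layer-only parameter: each encoder now builds only its first layer with `has_relative_position_bias=True`, so only that layer owns the embedding table. The bias is presumably computed once there and passed along to the remaining layers through the forward pass (that part of the code is not shown in this diff), so the per-layer tables previously allocated in every layer were redundant.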

parent d194d639
@@ -394,6 +394,7 @@ class WavLMAttention(nn.Module):
         dropout: float = 0.0,
         num_buckets: int = 320,
         max_distance: int = 800,
+        has_relative_position_bias: bool = True,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -418,7 +419,9 @@ class WavLMAttention(nn.Module):
         self.gru_rel_pos_const = nn.Parameter(torch.ones(1, self.num_heads, 1, 1))
         self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)
-        self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads)
+
+        if has_relative_position_bias:
+            self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads)

     def forward(
         self,
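To illustrate the guard added above, here is a minimal, hypothetical sketch (plain Python, not part of the diff): the embedding table only exists on modules created with `has_relative_position_bias=True`. `ToyAttention` and its parameter values are invented for illustration; only the `if has_relative_position_bias:` guard mirrors the actual change.

import torch.nn as nn

class ToyAttention(nn.Module):
    # Hypothetical, stripped-down analogue of WavLMAttention: the bucketed
    # relative-position embedding is allocated only when this module is the
    # one responsible for computing the bias.
    def __init__(self, num_heads=2, num_buckets=320, has_relative_position_bias=True):
        super().__init__()
        if has_relative_position_bias:
            # One learnable bias value per (bucket, head) pair.
            self.rel_attn_embed = nn.Embedding(num_buckets, num_heads)

first = ToyAttention(has_relative_position_bias=True)
later = ToyAttention(has_relative_position_bias=False)
print(hasattr(first, "rel_attn_embed"))  # True
print(hasattr(later, "rel_attn_embed"))  # False

Modules built with the flag off simply never create the attribute, so their checkpoints carry no unused parameters.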
@@ -573,7 +576,7 @@ class WavLMFeedForward(nn.Module):

 class WavLMEncoderLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
         super().__init__()
         self.attention = WavLMAttention(
             embed_dim=config.hidden_size,
@@ -581,6 +584,7 @@ class WavLMEncoderLayer(nn.Module):
             dropout=config.attention_dropout,
             num_buckets=config.num_buckets,
             max_distance=config.max_bucket_distance,
+            has_relative_position_bias=has_relative_position_bias,
         )
         self.dropout = nn.Dropout(config.hidden_dropout)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -613,7 +617,7 @@ class WavLMEncoderLayer(nn.Module):

 class WavLMEncoderLayerStableLayerNorm(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
         super().__init__()
         self.attention = WavLMAttention(
             embed_dim=config.hidden_size,
@@ -621,6 +625,7 @@ class WavLMEncoderLayerStableLayerNorm(nn.Module):
             dropout=config.attention_dropout,
             num_buckets=config.num_buckets,
             max_distance=config.max_bucket_distance,
+            has_relative_position_bias=has_relative_position_bias,
         )
         self.dropout = nn.Dropout(config.hidden_dropout)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -655,7 +660,9 @@ class WavLMEncoder(nn.Module):
         self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout)
-        self.layers = nn.ModuleList([WavLMEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layers = nn.ModuleList(
+            [WavLMEncoderLayer(config, has_relative_position_bias=(i == 0)) for i in range(config.num_hidden_layers)]
+        )
         self.gradient_checkpointing = False

     def forward(
@@ -743,7 +750,10 @@ class WavLMEncoderStableLayerNorm(nn.Module):
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout)
         self.layers = nn.ModuleList(
-            [WavLMEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
+            [
+                WavLMEncoderLayerStableLayerNorm(config, has_relative_position_bias=(i == 0))
+                for i in range(config.num_hidden_layers)
+            ]
         )
         self.gradient_checkpointing = False
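Putting the pieces together, here is a minimal toy sketch of the resulting encoder pattern, assuming (as the hunks above suggest, though the forward-pass diff is not shown here) that the first layer computes the position bias and later layers receive it as an argument. `ToyLayer` and its trivial bucketing are hypothetical; only the `has_relative_position_bias=(i == 0)` construction mirrors the commit.

import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    # Hypothetical sketch of the encoder-layer pattern after this commit:
    # only the layer built with has_relative_position_bias=True owns the
    # embedding table; every layer receives and forwards the bias tensor.
    def __init__(self, num_heads, num_buckets, has_relative_position_bias):
        super().__init__()
        if has_relative_position_bias:
            self.rel_attn_embed = nn.Embedding(num_buckets, num_heads)

    def forward(self, hidden_states, position_bias=None):
        if position_bias is None and hasattr(self, "rel_attn_embed"):
            # Compute the bias once, in the first layer: (heads, seq, seq).
            seq_len = hidden_states.size(1)
            buckets = torch.zeros(seq_len, seq_len, dtype=torch.long)  # toy bucketing
            position_bias = self.rel_attn_embed(buckets).permute(2, 0, 1)
        # ... attention using position_bias would go here ...
        return hidden_states, position_bias

layers = nn.ModuleList(
    [ToyLayer(num_heads=2, num_buckets=4, has_relative_position_bias=(i == 0)) for i in range(3)]
)
x, bias = torch.zeros(1, 5, 8), None
for layer in layers:
    x, bias = layer(x, position_bias=bias)
print(bias.shape)  # torch.Size([2, 5, 5]): computed by layer 0, reused afterwards

With this construction, only layer 0 carries learnable bias parameters, so the remaining layers no longer allocate embedding tables that the shared-bias computation never uses.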