Unverified commit c4a0fb51, authored Dec 16, 2021 by Patrick von Platen, committed by GitHub on Dec 16, 2021

[WavLM] Correct position bias computation (#14805)

Parent: d194d639
Showing 1 changed file with 15 additions and 5 deletions (+15, -5).
src/transformers/models/wavlm/modeling_wavlm.py
@@ -394,6 +394,7 @@ class WavLMAttention(nn.Module):
         dropout: float = 0.0,
         num_buckets: int = 320,
         max_distance: int = 800,
+        has_relative_position_bias: bool = True,
     ):
         super().__init__()
         self.embed_dim = embed_dim
@@ -418,7 +419,9 @@ class WavLMAttention(nn.Module):
         self.gru_rel_pos_const = nn.Parameter(torch.ones(1, self.num_heads, 1, 1))
         self.gru_rel_pos_linear = nn.Linear(self.head_dim, 8)
-        self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads)
+
+        if has_relative_position_bias:
+            self.rel_attn_embed = nn.Embedding(self.num_buckets, self.num_heads)
 
     def forward(
         self,
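The effect of the new flag is easy to check: when `has_relative_position_bias` is `False`, no `rel_attn_embed` table is allocated at all. A minimal sketch (the `embed_dim`/`num_heads` arguments are assumed from the surrounding constructor; the concrete values are illustrative, matching the base WavLM configuration):

```python
from transformers.models.wavlm.modeling_wavlm import WavLMAttention

# First encoder layer: owns the relative position embedding table.
attn_first = WavLMAttention(
    embed_dim=768,
    num_heads=12,
    num_buckets=320,
    max_distance=800,
    has_relative_position_bias=True,
)

# Later layers: skip the table and reuse the bias computed by layer 0.
attn_rest = WavLMAttention(
    embed_dim=768,
    num_heads=12,
    num_buckets=320,
    max_distance=800,
    has_relative_position_bias=False,
)

print(hasattr(attn_first, "rel_attn_embed"))  # True
print(hasattr(attn_rest, "rel_attn_embed"))   # False
```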
@@ -573,7 +576,7 @@ class WavLMFeedForward(nn.Module):
 
 
 class WavLMEncoderLayer(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
         super().__init__()
         self.attention = WavLMAttention(
             embed_dim=config.hidden_size,
@@ -581,6 +584,7 @@ class WavLMEncoderLayer(nn.Module):
             dropout=config.attention_dropout,
             num_buckets=config.num_buckets,
             max_distance=config.max_bucket_distance,
+            has_relative_position_bias=has_relative_position_bias,
         )
         self.dropout = nn.Dropout(config.hidden_dropout)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -613,7 +617,7 @@ class WavLMEncoderLayer(nn.Module):
 
 
 class WavLMEncoderLayerStableLayerNorm(nn.Module):
-    def __init__(self, config):
+    def __init__(self, config: WavLMConfig, has_relative_position_bias: bool = True):
         super().__init__()
         self.attention = WavLMAttention(
             embed_dim=config.hidden_size,
@@ -621,6 +625,7 @@ class WavLMEncoderLayerStableLayerNorm(nn.Module):
             dropout=config.attention_dropout,
             num_buckets=config.num_buckets,
             max_distance=config.max_bucket_distance,
+            has_relative_position_bias=has_relative_position_bias,
        )
         self.dropout = nn.Dropout(config.hidden_dropout)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -655,7 +660,9 @@ class WavLMEncoder(nn.Module):
         self.pos_conv_embed = WavLMPositionalConvEmbedding(config)
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout)
-        self.layers = nn.ModuleList([WavLMEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+        self.layers = nn.ModuleList(
+            [WavLMEncoderLayer(config, has_relative_position_bias=(i == 0)) for i in range(config.num_hidden_layers)]
+        )
         self.gradient_checkpointing = False
 
     def forward(
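With the `has_relative_position_bias=(i == 0)` wiring, only layer 0 of the encoder should carry the embedding table. A quick sanity check, assuming a default `WavLMConfig` (12 hidden layers):

```python
from transformers import WavLMConfig
from transformers.models.wavlm.modeling_wavlm import WavLMEncoder

config = WavLMConfig()
encoder = WavLMEncoder(config)

# After this commit, only the first layer allocates rel_attn_embed.
for i, layer in enumerate(encoder.layers):
    print(i, hasattr(layer.attention, "rel_attn_embed"))
# Expected: "0 True", then "1 False" through "11 False"
```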
@@ -743,7 +750,10 @@ class WavLMEncoderStableLayerNorm(nn.Module):
         self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.dropout = nn.Dropout(config.hidden_dropout)
         self.layers = nn.ModuleList(
-            [WavLMEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
+            [
+                WavLMEncoderLayerStableLayerNorm(config, has_relative_position_bias=(i == 0))
+                for i in range(config.num_hidden_layers)
+            ]
         )
         self.gradient_checkpointing = False
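The stable layer norm encoder gets the same treatment: the position bias computed by the first layer is reused by the remaining layers, so outputs keep their usual shape. A hedged end-to-end smoke test (`do_stable_layer_norm` and the forward signature are assumed to follow the Wav2Vec2-style API this file is modeled on):

```python
import torch
from transformers import WavLMConfig
from transformers.models.wavlm.modeling_wavlm import WavLMEncoderStableLayerNorm

config = WavLMConfig(do_stable_layer_norm=True)
encoder = WavLMEncoderStableLayerNorm(config)

# Random features standing in for the convolutional front-end output.
hidden_states = torch.randn(1, 50, config.hidden_size)

with torch.no_grad():
    outputs = encoder(hidden_states)

print(outputs.last_hidden_state.shape)  # torch.Size([1, 50, 768])
```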