chenpangpang / transformers · Commits

Commit 9e00566b (unverified)
Authored Feb 09, 2022 by Sanchit Gandhi; committed via GitHub on Feb 09, 2022

Add Wav2Vec2 Adapter Weights to Flax (#15566)

* Add Wav2Vec2 Adapter Weights to Flax
* Suggested changes
Parent: 1f60bc46

Changes: 1 changed file with 95 additions and 3 deletions

src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py (+95, -3)
@@ -766,6 +766,73 @@ class FlaxWav2Vec2GumbelVectorQuantizer(nn.Module):
         return codevectors, perplexity
 
+
+class FlaxWav2Vec2Adapter(nn.Module):
+    config: Wav2Vec2Config
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        # hidden_states require down-projection if feature dims don't match
+        if self.config.output_hidden_size != self.config.hidden_size:
+            self.proj = nn.Dense(
+                self.config.output_hidden_size,
+                kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+                dtype=self.dtype,
+            )
+            self.proj_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
+        else:
+            self.proj = self.proj_layer_norm = None
+
+        self.layers = FlaxWav2Vec2AdapterLayersCollection(self.config, dtype=self.dtype)
+
+    def __call__(self, hidden_states, deterministic=True):
+        # down-project hidden_states if required
+        if self.proj is not None and self.proj_layer_norm is not None:
+            hidden_states = self.proj(hidden_states)
+            hidden_states = self.proj_layer_norm(hidden_states)
+
+        hidden_states = self.layers(hidden_states)
+
+        return hidden_states
+
+
+class FlaxWav2Vec2AdapterLayer(nn.Module):
+    config: Wav2Vec2Config
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.conv = nn.Conv(
+            features=2 * self.config.output_hidden_size,
+            kernel_size=(self.config.adapter_kernel_size,),
+            strides=(self.config.adapter_stride,),
+            padding=((1, 1),),
+            kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
+            dtype=self.dtype,
+        )
+
+    def __call__(self, hidden_states):
+        hidden_states = self.conv(hidden_states)
+        hidden_states = nn.glu(hidden_states, axis=2)
+
+        return hidden_states
+
+
+class FlaxWav2Vec2AdapterLayersCollection(nn.Module):
+    config: Wav2Vec2Config
+    dtype: jnp.dtype = jnp.float32
+
+    def setup(self):
+        self.layers = [
+            FlaxWav2Vec2AdapterLayer(self.config, name=str(i), dtype=self.dtype)
+            for i in range(self.config.num_adapter_layers)
+        ]
+
+    def __call__(self, hidden_states):
+        for conv_layer in self.layers:
+            hidden_states = conv_layer(hidden_states)
+
+        return hidden_states
+
+
 class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
     """
     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
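For orientation: each adapter layer is a strided 1D convolution that doubles the feature dimension, followed by a gated linear unit (GLU) that gates it back down, so the net effect is to downsample the time axis by adapter_stride while keeping the feature width at output_hidden_size. Below is a minimal, self-contained sketch of that shape arithmetic in plain Flax; the module name, the standalone `hidden` field, and the concrete sizes are illustrative stand-ins, not part of the commit.

import jax
import jax.numpy as jnp
import flax.linen as nn

class AdapterLayerSketch(nn.Module):
    hidden: int = 768  # stand-in for config.output_hidden_size

    @nn.compact
    def __call__(self, x):
        # strided conv doubles the feature dim: (B, T, H) -> (B, ~T/2, 2H)
        x = nn.Conv(features=2 * self.hidden, kernel_size=(3,), strides=(2,), padding=((1, 1),))(x)
        # GLU gates it back down: (B, ~T/2, 2H) -> (B, ~T/2, H)
        return jax.nn.glu(x, axis=2)

layer = AdapterLayerSketch()
x = jnp.ones((1, 49, 768))  # (batch, frames, features)
params = layer.init(jax.random.PRNGKey(0), x)
print(layer.apply(params, x).shape)  # (1, 25, 768)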
@@ -840,7 +907,9 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel):
             rngs=rngs,
         )
 
-    def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
+    def _get_feat_extract_output_lengths(
+        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
+    ):
         return self.module._get_feat_extract_output_lengths(input_lengths)
@@ -860,6 +929,8 @@ class FlaxWav2Vec2Module(nn.Module):
         else:
             raise NotImplementedError("``config.do_stable_layer_norm is False`` is currently not supported.")
 
+        self.adapter = FlaxWav2Vec2Adapter(self.config, dtype=self.dtype) if self.config.add_adapter else None
+
     def __call__(
         self,
         input_values,
@@ -905,6 +976,9 @@ class FlaxWav2Vec2Module(nn.Module):
         hidden_states = encoder_outputs[0]
 
+        if self.adapter is not None:
+            hidden_states = self.adapter(hidden_states)
+
         if not return_dict:
             return (hidden_states, extract_features) + encoder_outputs[1:]
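Taken together, the two hunks above mean a checkpoint configured with add_adapter=True now runs the adapter over the encoder output before the model returns it. A hedged usage sketch follows (randomly initialized weights; per the NotImplementedError above, the Flax port needs do_stable_layer_norm=True, and feat_extract_norm="layer" is the matching feature-encoder setting; the shape comment assumes the config defaults of num_adapter_layers=3 and adapter_stride=2).

import jax.numpy as jnp
from transformers import Wav2Vec2Config, FlaxWav2Vec2Model

config = Wav2Vec2Config(
    do_stable_layer_norm=True,   # the Flax module raises NotImplementedError otherwise
    feat_extract_norm="layer",
    add_adapter=True,            # enables the new FlaxWav2Vec2Adapter
)
model = FlaxWav2Vec2Model(config)  # random init, no pretrained weights

input_values = jnp.ones((1, 16000))  # one second of 16 kHz audio
outputs = model(input_values)
# ~49 encoder frames are downsampled three more times: 49 -> 25 -> 13 -> 7
print(outputs.last_hidden_state.shape)  # expect (1, 7, 768) with these defaults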
@@ -915,11 +989,15 @@ class FlaxWav2Vec2Module(nn.Module):
             attentions=encoder_outputs.attentions,
         )
 
-    def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
+    def _get_feat_extract_output_lengths(
+        self, input_lengths: Union[jnp.ndarray, int], add_adapter: Optional[bool] = None
+    ):
         """
         Computes the output length of the convolutional layers
         """
 
+        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
         def _conv_out_length(input_length, kernel_size, stride):
             # 1D convolutional layer output length formula taken
             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
@@ -928,6 +1006,10 @@ class FlaxWav2Vec2Module(nn.Module):
         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
 
+        if add_adapter:
+            for _ in range(self.config.num_adapter_layers):
+                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
         return input_lengths
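The kernel size of 1 in the new accounting loop looks surprising at first, since the adapter conv actually uses adapter_kernel_size=3; but with padding (1, 1) the padded-conv formula (L + 2 - 3) // stride + 1 reduces to (L - 1) // stride + 1, which is exactly _conv_out_length(L, 1, stride). A standalone check of that arithmetic (plain Python; the 49-frame starting point and the config defaults are assumptions for illustration):

def conv_out_length(input_length, kernel_size, stride):
    # same unpadded 1D conv formula as the _conv_out_length helper in the diff
    return (input_length - kernel_size) // stride + 1

frames = 49  # encoder frames for one second of 16 kHz audio
for _ in range(3):  # num_adapter_layers = 3
    # padded kernel-3 conv and unpadded kernel-1 conv give the same length
    assert conv_out_length(frames + 2, 3, 2) == conv_out_length(frames, 1, 2)
    frames = conv_out_length(frames, 1, 2)  # adapter_stride = 2
print(frames)  # 7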
@@ -1021,11 +1103,17 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
         return FlaxCausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
 
-    def _get_feat_extract_output_lengths(self, input_lengths: Union[jnp.ndarray, int]):
+    def _get_feat_extract_output_lengths(
+        self,
+        input_lengths: Union[jnp.ndarray, int],
+        add_adapter: Optional[bool] = None,
+    ):
         """
         Computes the output length of the convolutional layers
         """
 
+        add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
+
         def _conv_out_length(input_length, kernel_size, stride):
             # 1D convolutional layer output length formula taken
             # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
@@ -1034,6 +1122,10 @@ class FlaxWav2Vec2ForCTCModule(nn.Module):
         for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
             input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
 
+        if add_adapter:
+            for _ in range(self.config.num_adapter_layers):
+                input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
+
         return input_lengths
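The CTC module carries the same helper because downstream CTC loss code typically uses these lengths to mask padded frames. A hedged sanity check (random init; add_adapter is left at its default of False, so the logit length and the reported feature length should agree; the 49-frame figure assumes the default conv stack on one second of 16 kHz audio):

import jax.numpy as jnp
from transformers import Wav2Vec2Config, FlaxWav2Vec2ForCTC

config = Wav2Vec2Config(do_stable_layer_norm=True, feat_extract_norm="layer")
model = FlaxWav2Vec2ForCTC(config)  # random init

logits = model(jnp.ones((1, 16000))).logits
print(logits.shape[1])                                # 49 frames of vocab logits
print(model._get_feat_extract_output_lengths(16000))  # 49, matching the logits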