chenpangpang / transformers

Commit bc6f51e5 (Unverified)
Authored Jun 09, 2021 by Patrick von Platen, committed by GitHub on Jun 09, 2021
Parent: 61e19198

[Wav2Vec2ForPretraining] Correct checkpoints wav2vec2 & fix tests (#12089)

* fix_torch_device_generate_test
* remove @
* fix tests
Showing 1 changed file with 13 additions and 8 deletions.

tests/test_modeling_wav2vec2.py (+13, -8), view file @ bc6f51e5
@@ -349,6 +349,8 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
             module.bias.data.fill_(3)
+        if hasattr(module, "codevectors") and module.codevectors is not None:
+            module.codevectors.data.fill_(3)
         if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
             module.masked_spec_embed.data.fill_(3)

     @slow
     def test_model_from_pretrained(self):
@@ -487,6 +489,8 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
             module.bias.data.fill_(3)
+        if hasattr(module, "codevectors") and module.codevectors is not None:
+            module.codevectors.data.fill_(3)
         if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
             module.masked_spec_embed.data.fill_(3)

     def test_model_for_pretraining(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
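For context, the two hunks above extend each test class's mock weight-initialization helper so that the quantizer's codevectors parameter, which only Wav2Vec2ForPreTraining adds, is also overwritten with a constant and therefore covered by the save/load checks. A minimal sketch of such a helper, assuming the helper name and the branches not shown in the diff:

    # Hedged sketch: helper name and the non-diffed branches are assumptions
    def _mock_init_weights(self, module):
        # fill every known parameter with a constant so tests can detect re-initialization
        if hasattr(module, "weight") and module.weight is not None:
            module.weight.data.fill_(3)
        if hasattr(module, "bias") and module.bias is not None:
            module.bias.data.fill_(3)
        # added in this commit: quantizer codevectors used by Wav2Vec2ForPreTraining
        if hasattr(module, "codevectors") and module.codevectors is not None:
            module.codevectors.data.fill_(3)
        if hasattr(module, "masked_spec_embed") and module.masked_spec_embed is not None:
            module.masked_spec_embed.data.fill_(3)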
@@ -677,10 +681,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)

     def test_inference_integration(self):
-        model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
+        model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
         model.to(torch_device)
         feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-            "patrickvonplaten/wav2vec2-base", return_attention_mask=True
+            "facebook/wav2vec2-base", return_attention_mask=True
         )
         input_speech = self._load_datasamples(2)
@@ -723,10 +727,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         self.assertTrue(torch.allclose(cosine_sim_masked, expected_cosine_sim_masked, atol=1e-3))

     def test_inference_pretrained(self):
-        model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
+        model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
         model.to(torch_device)
         feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-            "patrickvonplaten/wav2vec2-base", return_attention_mask=True
+            "facebook/wav2vec2-base", return_attention_mask=True
         )
         input_speech = self._load_datasamples(2)
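The hunks above switch the integration tests to the corrected facebook/wav2vec2-base checkpoint. A standalone sketch of the same loading pattern outside the test suite; the dummy waveform and 16 kHz sampling rate stand in for the tests' _load_datasamples(2) helper and are assumptions, as is inspecting the projected states:

    import torch
    from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForPreTraining

    # corrected checkpoint name from this commit
    model = Wav2Vec2ForPreTraining.from_pretrained("facebook/wav2vec2-base")
    model.eval()

    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
        "facebook/wav2vec2-base", return_attention_mask=True
    )

    # dummy 1-second, 16 kHz waveform standing in for _load_datasamples(2)
    speech = [torch.randn(16000).numpy()]
    inputs = feature_extractor(speech, sampling_rate=16000, return_tensors="pt", padding=True)

    with torch.no_grad():
        outputs = model(inputs.input_values, attention_mask=inputs.attention_mask)

    # the integration tests compare these two tensors via cosine similarity
    print(outputs.projected_states.shape, outputs.projected_quantized_states.shape)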
@@ -761,7 +765,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         # ... now compare to randomly initialized model

-        config = Wav2Vec2Config.from_pretrained("patrickvonplaten/wav2vec2-base")
+        config = Wav2Vec2Config.from_pretrained("facebook/wav2vec2-base")
         model_rand = Wav2Vec2ForPreTraining(config).to(torch_device).eval()

         with torch.no_grad():
@@ -785,9 +789,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         # => the cosine similarity between quantized states and predicted states is very likely < 0.1
         self.assertTrue(cosine_sim_masked.mean().item() - 5 * cosine_sim_masked_rand.mean().item() > 0)

+    @unittest.skipIf(torch_device != "cpu", "cannot make deterministic on GPU")
     def test_loss_pretraining(self):
         model = Wav2Vec2ForPreTraining.from_pretrained(
-            "patrickvonplaten/wav2vec2-base",
+            "facebook/wav2vec2-base",
             attention_dropout=0.0,
             feat_proj_dropout=0.0,
             hidden_dropout=0.0,
@@ -796,7 +801,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         model.to(torch_device).train()

         feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
-            "patrickvonplaten/wav2vec2-base", return_attention_mask=True
+            "facebook/wav2vec2-base", return_attention_mask=True
         )
         input_speech = self._load_datasamples(2)
@@ -829,6 +834,6 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
         self.assertTrue(abs(diversity_loss.item() - 0.8859) < 1e-3)

         # check overall loss (contrastive loss + diversity loss)
-        expected_loss = 62.5170 if model.device.type == "cpu" else 50.3612
+        expected_loss = 62.5170
         self.assertTrue(abs(outputs.loss.item() - expected_loss) < 1e-3)
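Note that the last hunk can drop the GPU branch of expected_loss because of the @unittest.skipIf(torch_device != "cpu", ...) decorator added earlier in this diff: test_loss_pretraining now runs only on CPU, where the loss can be made deterministic, so a single expected value suffices.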