Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
bc6f51e5
Unverified
Commit
bc6f51e5
authored
Jun 09, 2021
by
Patrick von Platen
Committed by
GitHub
Jun 09, 2021
Browse files
[Wav2Vec2ForPretraining] Correct checkpoints wav2vec2 & fix tests (#12089)
* fix_torch_device_generate_test * remove @ * fix tests
parent
61e19198
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
8 deletions
+13
-8
tests/test_modeling_wav2vec2.py
tests/test_modeling_wav2vec2.py
+13
-8
No files found.
tests/test_modeling_wav2vec2.py
View file @
bc6f51e5
...
@@ -349,6 +349,8 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -349,6 +349,8 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
module
.
bias
.
data
.
fill_
(
3
)
module
.
bias
.
data
.
fill_
(
3
)
if
hasattr
(
module
,
"codevectors"
)
and
module
.
codevectors
is
not
None
:
if
hasattr
(
module
,
"codevectors"
)
and
module
.
codevectors
is
not
None
:
module
.
codevectors
.
data
.
fill_
(
3
)
module
.
codevectors
.
data
.
fill_
(
3
)
if
hasattr
(
module
,
"masked_spec_embed"
)
and
module
.
masked_spec_embed
is
not
None
:
module
.
masked_spec_embed
.
data
.
fill_
(
3
)
@
slow
@
slow
def
test_model_from_pretrained
(
self
):
def
test_model_from_pretrained
(
self
):
...
@@ -487,6 +489,8 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -487,6 +489,8 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
module
.
bias
.
data
.
fill_
(
3
)
module
.
bias
.
data
.
fill_
(
3
)
if
hasattr
(
module
,
"codevectors"
)
and
module
.
codevectors
is
not
None
:
if
hasattr
(
module
,
"codevectors"
)
and
module
.
codevectors
is
not
None
:
module
.
codevectors
.
data
.
fill_
(
3
)
module
.
codevectors
.
data
.
fill_
(
3
)
if
hasattr
(
module
,
"masked_spec_embed"
)
and
module
.
masked_spec_embed
is
not
None
:
module
.
masked_spec_embed
.
data
.
fill_
(
3
)
def
test_model_for_pretraining
(
self
):
def
test_model_for_pretraining
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
...
@@ -677,10 +681,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
...
@@ -677,10 +681,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
self
.
assertListEqual
(
predicted_trans
,
EXPECTED_TRANSCRIPTIONS
)
self
.
assertListEqual
(
predicted_trans
,
EXPECTED_TRANSCRIPTIONS
)
def
test_inference_integration
(
self
):
def
test_inference_integration
(
self
):
model
=
Wav2Vec2ForPreTraining
.
from_pretrained
(
"
patrickvonplaten
/wav2vec2-base"
)
model
=
Wav2Vec2ForPreTraining
.
from_pretrained
(
"
facebook
/wav2vec2-base"
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
feature_extractor
=
Wav2Vec2FeatureExtractor
.
from_pretrained
(
feature_extractor
=
Wav2Vec2FeatureExtractor
.
from_pretrained
(
"
patrickvonplaten
/wav2vec2-base"
,
return_attention_mask
=
True
"
facebook
/wav2vec2-base"
,
return_attention_mask
=
True
)
)
input_speech
=
self
.
_load_datasamples
(
2
)
input_speech
=
self
.
_load_datasamples
(
2
)
...
@@ -723,10 +727,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
...
@@ -723,10 +727,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
self
.
assertTrue
(
torch
.
allclose
(
cosine_sim_masked
,
expected_cosine_sim_masked
,
atol
=
1e-3
))
self
.
assertTrue
(
torch
.
allclose
(
cosine_sim_masked
,
expected_cosine_sim_masked
,
atol
=
1e-3
))
def
test_inference_pretrained
(
self
):
def
test_inference_pretrained
(
self
):
model
=
Wav2Vec2ForPreTraining
.
from_pretrained
(
"
patrickvonplaten
/wav2vec2-base"
)
model
=
Wav2Vec2ForPreTraining
.
from_pretrained
(
"
facebook
/wav2vec2-base"
)
model
.
to
(
torch_device
)
model
.
to
(
torch_device
)
feature_extractor
=
Wav2Vec2FeatureExtractor
.
from_pretrained
(
feature_extractor
=
Wav2Vec2FeatureExtractor
.
from_pretrained
(
"
patrickvonplaten
/wav2vec2-base"
,
return_attention_mask
=
True
"
facebook
/wav2vec2-base"
,
return_attention_mask
=
True
)
)
input_speech
=
self
.
_load_datasamples
(
2
)
input_speech
=
self
.
_load_datasamples
(
2
)
...
@@ -761,7 +765,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
...
@@ -761,7 +765,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
# ... now compare to randomly initialized model
# ... now compare to randomly initialized model
config
=
Wav2Vec2Config
.
from_pretrained
(
"
patrickvonplaten
/wav2vec2-base"
)
config
=
Wav2Vec2Config
.
from_pretrained
(
"
facebook
/wav2vec2-base"
)
model_rand
=
Wav2Vec2ForPreTraining
(
config
).
to
(
torch_device
).
eval
()
model_rand
=
Wav2Vec2ForPreTraining
(
config
).
to
(
torch_device
).
eval
()
with
torch
.
no_grad
():
with
torch
.
no_grad
():
...
@@ -785,9 +789,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
...
@@ -785,9 +789,10 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
# => the cosine similarity between quantized states and predicted states is very likely < 0.1
# => the cosine similarity between quantized states and predicted states is very likely < 0.1
self
.
assertTrue
(
cosine_sim_masked
.
mean
().
item
()
-
5
*
cosine_sim_masked_rand
.
mean
().
item
()
>
0
)
self
.
assertTrue
(
cosine_sim_masked
.
mean
().
item
()
-
5
*
cosine_sim_masked_rand
.
mean
().
item
()
>
0
)
@
unittest
.
skipIf
(
torch_device
!=
"cpu"
,
"cannot make deterministic on GPU"
)
def
test_loss_pretraining
(
self
):
def
test_loss_pretraining
(
self
):
model
=
Wav2Vec2ForPreTraining
.
from_pretrained
(
model
=
Wav2Vec2ForPreTraining
.
from_pretrained
(
"
patrickvonplaten
/wav2vec2-base"
,
"
facebook
/wav2vec2-base"
,
attention_dropout
=
0.0
,
attention_dropout
=
0.0
,
feat_proj_dropout
=
0.0
,
feat_proj_dropout
=
0.0
,
hidden_dropout
=
0.0
,
hidden_dropout
=
0.0
,
...
@@ -796,7 +801,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
...
@@ -796,7 +801,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
model
.
to
(
torch_device
).
train
()
model
.
to
(
torch_device
).
train
()
feature_extractor
=
Wav2Vec2FeatureExtractor
.
from_pretrained
(
feature_extractor
=
Wav2Vec2FeatureExtractor
.
from_pretrained
(
"
patrickvonplaten
/wav2vec2-base"
,
return_attention_mask
=
True
"
facebook
/wav2vec2-base"
,
return_attention_mask
=
True
)
)
input_speech
=
self
.
_load_datasamples
(
2
)
input_speech
=
self
.
_load_datasamples
(
2
)
...
@@ -829,6 +834,6 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
...
@@ -829,6 +834,6 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
self
.
assertTrue
(
abs
(
diversity_loss
.
item
()
-
0.8859
)
<
1e-3
)
self
.
assertTrue
(
abs
(
diversity_loss
.
item
()
-
0.8859
)
<
1e-3
)
# check overall loss (contrastive loss + diversity loss)
# check overall loss (contrastive loss + diversity loss)
expected_loss
=
62.5170
if
model
.
device
.
type
==
"cpu"
else
50.3612
expected_loss
=
62.5170
self
.
assertTrue
(
abs
(
outputs
.
loss
.
item
()
-
expected_loss
)
<
1e-3
)
self
.
assertTrue
(
abs
(
outputs
.
loss
.
item
()
-
expected_loss
)
<
1e-3
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment