Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
ee55ea69
"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b72752f06830cb6cf8d21c284f68e15faa100c4d"
Unverified
Commit
ee55ea69
authored
Dec 23, 2021
by
Anton Lozhkov
Committed by
GitHub
Dec 23, 2021
Browse files
Update diarization and WavLM tolerances (#14902)
parent
ef47d4f8
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
7 deletions
+14
-7
tests/test_modeling_unispeech_sat.py
tests/test_modeling_unispeech_sat.py
+4
-2
tests/test_modeling_wav2vec2.py
tests/test_modeling_wav2vec2.py
+4
-2
tests/test_modeling_wavlm.py
tests/test_modeling_wavlm.py
+6
-3
No files found.
tests/test_modeling_unispeech_sat.py
View file @
ee55ea69
...
@@ -889,7 +889,8 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
...
@@ -889,7 +889,8 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
)
)
self
.
assertEqual
(
labels
[
0
,
:,
0
].
sum
(),
270
)
self
.
assertEqual
(
labels
[
0
,
:,
0
].
sum
(),
270
)
self
.
assertEqual
(
labels
[
0
,
:,
1
].
sum
(),
647
)
self
.
assertEqual
(
labels
[
0
,
:,
1
].
sum
(),
647
)
self
.
assertTrue
(
torch
.
allclose
(
outputs
.
logits
[:,
:
4
],
expected_logits
,
atol
=
1e-3
))
# TODO: update the tolerance after the CI moves to torch 1.10
self
.
assertTrue
(
torch
.
allclose
(
outputs
.
logits
[:,
:
4
],
expected_logits
,
atol
=
1e-2
))
def
test_inference_speaker_verification
(
self
):
def
test_inference_speaker_verification
(
self
):
model
=
UniSpeechSatForXVector
.
from_pretrained
(
"microsoft/unispeech-sat-base-plus-sv"
).
to
(
torch_device
)
model
=
UniSpeechSatForXVector
.
from_pretrained
(
"microsoft/unispeech-sat-base-plus-sv"
).
to
(
torch_device
)
...
@@ -913,4 +914,5 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
...
@@ -913,4 +914,5 @@ class UniSpeechSatModelIntegrationTest(unittest.TestCase):
# id10002 vs id10004
# id10002 vs id10004
self
.
assertAlmostEqual
(
cosine_sim
(
embeddings
[
2
],
embeddings
[
3
]).
item
(),
0.5616
,
3
)
self
.
assertAlmostEqual
(
cosine_sim
(
embeddings
[
2
],
embeddings
[
3
]).
item
(),
0.5616
,
3
)
self
.
assertAlmostEqual
(
outputs
.
loss
.
item
(),
18.5925
,
3
)
# TODO: update the tolerance after the CI moves to torch 1.10
self
.
assertAlmostEqual
(
outputs
.
loss
.
item
(),
18.5925
,
2
)
tests/test_modeling_wav2vec2.py
View file @
ee55ea69
...
@@ -1480,7 +1480,8 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
...
@@ -1480,7 +1480,8 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
)
)
self
.
assertEqual
(
labels
[
0
,
:,
0
].
sum
(),
555
)
self
.
assertEqual
(
labels
[
0
,
:,
0
].
sum
(),
555
)
self
.
assertEqual
(
labels
[
0
,
:,
1
].
sum
(),
299
)
self
.
assertEqual
(
labels
[
0
,
:,
1
].
sum
(),
299
)
self
.
assertTrue
(
torch
.
allclose
(
outputs
.
logits
[:,
:
4
],
expected_logits
,
atol
=
1e-3
))
# TODO: update the tolerance after the CI moves to torch 1.10
self
.
assertTrue
(
torch
.
allclose
(
outputs
.
logits
[:,
:
4
],
expected_logits
,
atol
=
1e-2
))
def
test_inference_speaker_verification
(
self
):
def
test_inference_speaker_verification
(
self
):
model
=
Wav2Vec2ForXVector
.
from_pretrained
(
"anton-l/wav2vec2-base-superb-sv"
).
to
(
torch_device
)
model
=
Wav2Vec2ForXVector
.
from_pretrained
(
"anton-l/wav2vec2-base-superb-sv"
).
to
(
torch_device
)
...
@@ -1504,4 +1505,5 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
...
@@ -1504,4 +1505,5 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
# id10002 vs id10004
# id10002 vs id10004
self
.
assertAlmostEqual
(
cosine_sim
(
embeddings
[
2
],
embeddings
[
3
]).
numpy
(),
0.7594
,
3
)
self
.
assertAlmostEqual
(
cosine_sim
(
embeddings
[
2
],
embeddings
[
3
]).
numpy
(),
0.7594
,
3
)
self
.
assertAlmostEqual
(
outputs
.
loss
.
item
(),
17.7963
,
3
)
# TODO: update the tolerance after the CI moves to torch 1.10
self
.
assertAlmostEqual
(
outputs
.
loss
.
item
(),
17.7963
,
2
)
tests/test_modeling_wavlm.py
View file @
ee55ea69
...
@@ -496,7 +496,8 @@ class WavLMModelIntegrationTest(unittest.TestCase):
...
@@ -496,7 +496,8 @@ class WavLMModelIntegrationTest(unittest.TestCase):
EXPECTED_HIDDEN_STATES_SLICE
=
torch
.
tensor
(
EXPECTED_HIDDEN_STATES_SLICE
=
torch
.
tensor
(
[[[
0.0577
,
0.1161
],
[
0.0579
,
0.1165
]],
[[
0.0199
,
0.1237
],
[
0.0059
,
0.0605
]]]
[[[
0.0577
,
0.1161
],
[
0.0579
,
0.1165
]],
[[
0.0199
,
0.1237
],
[
0.0059
,
0.0605
]]]
)
)
self
.
assertTrue
(
torch
.
allclose
(
hidden_states_slice
,
EXPECTED_HIDDEN_STATES_SLICE
,
rtol
=
1e-2
))
# TODO: update the tolerance after the CI moves to torch 1.10
self
.
assertTrue
(
torch
.
allclose
(
hidden_states_slice
,
EXPECTED_HIDDEN_STATES_SLICE
,
atol
=
1e-2
))
def
test_inference_large
(
self
):
def
test_inference_large
(
self
):
model
=
WavLMModel
.
from_pretrained
(
"microsoft/wavlm-large"
).
to
(
torch_device
)
model
=
WavLMModel
.
from_pretrained
(
"microsoft/wavlm-large"
).
to
(
torch_device
)
...
@@ -546,7 +547,8 @@ class WavLMModelIntegrationTest(unittest.TestCase):
...
@@ -546,7 +547,8 @@ class WavLMModelIntegrationTest(unittest.TestCase):
)
)
self
.
assertEqual
(
labels
[
0
,
:,
0
].
sum
(),
258
)
self
.
assertEqual
(
labels
[
0
,
:,
0
].
sum
(),
258
)
self
.
assertEqual
(
labels
[
0
,
:,
1
].
sum
(),
647
)
self
.
assertEqual
(
labels
[
0
,
:,
1
].
sum
(),
647
)
self
.
assertTrue
(
torch
.
allclose
(
outputs
.
logits
[:,
:
4
],
expected_logits
,
atol
=
1e-3
))
# TODO: update the tolerance after the CI moves to torch 1.10
self
.
assertTrue
(
torch
.
allclose
(
outputs
.
logits
[:,
:
4
],
expected_logits
,
atol
=
1e-2
))
def
test_inference_speaker_verification
(
self
):
def
test_inference_speaker_verification
(
self
):
model
=
WavLMForXVector
.
from_pretrained
(
"microsoft/wavlm-base-plus-sv"
).
to
(
torch_device
)
model
=
WavLMForXVector
.
from_pretrained
(
"microsoft/wavlm-base-plus-sv"
).
to
(
torch_device
)
...
@@ -570,4 +572,5 @@ class WavLMModelIntegrationTest(unittest.TestCase):
...
@@ -570,4 +572,5 @@ class WavLMModelIntegrationTest(unittest.TestCase):
# id10002 vs id10004
# id10002 vs id10004
self
.
assertAlmostEqual
(
cosine_sim
(
embeddings
[
2
],
embeddings
[
3
]).
item
(),
0.4780
,
3
)
self
.
assertAlmostEqual
(
cosine_sim
(
embeddings
[
2
],
embeddings
[
3
]).
item
(),
0.4780
,
3
)
self
.
assertAlmostEqual
(
outputs
.
loss
.
item
(),
18.4154
,
3
)
# TODO: update the tolerance after the CI moves to torch 1.10
self
.
assertAlmostEqual
(
outputs
.
loss
.
item
(),
18.4154
,
2
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment