chenpangpang / transformers · Commits · e0d31a89

Commit e0d31a89 (unverified), authored Sep 26, 2021 by Anton Lozhkov, committed by GitHub on Sep 26, 2021

[Tests] Cast Hubert test models to fp16 (#13755)

Parent: 400c5a15
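The diff below applies the same change to each Hubert integration test: the checkpoint is loaded directly in half precision via torch_dtype=torch.float16, the floating-point inputs are cast with .half() before being moved to the device, the expected logits are built as fp16 tensors, and one tolerance is loosened to account for the reduced precision. A minimal sketch of that pattern, assuming a CUDA device and network access to the Hub; the dummy audio and variable names are illustrative, not a copy of the test file:

    import numpy as np
    import torch
    from transformers import HubertForCTC, Wav2Vec2Processor

    torch_device = "cuda"  # assumption: the integration tests run on a GPU

    # Load the checkpoint directly in half precision instead of casting afterwards.
    model = HubertForCTC.from_pretrained(
        "facebook/hubert-large-ls960-ft", torch_dtype=torch.float16
    ).to(torch_device)
    processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft", do_lower_case=True)

    # Placeholder audio (two 1-second 16 kHz clips); the real test loads LibriSpeech samples.
    speech = [np.random.randn(16000).astype(np.float32) for _ in range(2)]
    inputs = processor(speech, sampling_rate=16000, return_tensors="pt", padding=True)

    # Float inputs must match the model dtype, so cast them with .half();
    # the integer attention mask keeps its original dtype.
    input_values = inputs.input_values.half().to(torch_device)
    attention_mask = inputs.attention_mask.to(torch_device)

    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits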
Showing 1 changed file with 35 additions and 17 deletions

tests/test_modeling_hubert.py (+35 -17, view file @ e0d31a89)
@@ -635,14 +635,16 @@ class HubertModelIntegrationTest(unittest.TestCase):
         return ds[:num_samples]
 
     def test_inference_ctc_batched(self):
-        model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").to(torch_device)
+        model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft", torch_dtype=torch.float16).to(
+            torch_device
+        )
         processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft", do_lower_case=True)
 
         input_speech = self._load_datasamples(2)
 
         inputs = processor(input_speech, return_tensors="pt", padding=True)
 
-        input_values = inputs.input_values.to(torch_device)
+        input_values = inputs.input_values.half().to(torch_device)
         attention_mask = inputs.attention_mask.to(torch_device)
 
         with torch.no_grad():
@@ -658,12 +660,14 @@ class HubertModelIntegrationTest(unittest.TestCase):
         self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
 
     def test_inference_keyword_spotting(self):
-        model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks").to(torch_device)
+        model = HubertForSequenceClassification.from_pretrained(
+            "superb/hubert-base-superb-ks", torch_dtype=torch.float16
+        ).to(torch_device)
         processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
         input_data = self._load_superb("ks", 4)
         inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
 
-        input_values = inputs.input_values.to(torch_device)
+        input_values = inputs.input_values.half().to(torch_device)
         attention_mask = inputs.attention_mask.to(torch_device)
         with torch.no_grad():
             outputs = model(input_values, attention_mask=attention_mask)
@@ -671,18 +675,20 @@ class HubertModelIntegrationTest(unittest.TestCase):
         expected_labels = [2, 6, 10, 9]
         # s3prl logits for the same batch
-        expected_logits = torch.tensor([7.6692, 17.7795, 11.1562, 11.8232], device=torch_device)
+        expected_logits = torch.tensor([7.6692, 17.7795, 11.1562, 11.8232], dtype=torch.float16, device=torch_device)
 
         self.assertListEqual(predicted_ids.tolist(), expected_labels)
-        self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=1e-2))
+        self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=2e-2))
 
     def test_inference_intent_classification(self):
-        model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ic").to(torch_device)
+        model = HubertForSequenceClassification.from_pretrained(
+            "superb/hubert-base-superb-ic", torch_dtype=torch.float16
+        ).to(torch_device)
         processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ic")
         input_data = self._load_superb("ic", 4)
         inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
 
        input_values = inputs.input_values.to(torch_device)
+        input_values = inputs.input_values.half().to(torch_device)
         attention_mask = inputs.attention_mask.to(torch_device)
 
         with torch.no_grad():
             outputs = model(input_values, attention_mask=attention_mask)
@@ -692,11 +698,17 @@ class HubertModelIntegrationTest(unittest.TestCase):
         predicted_logits_location, predicted_ids_location = torch.max(outputs.logits[:, 20:24], dim=-1)
 
         expected_labels_action = [1, 0, 4, 3]
-        expected_logits_action = torch.tensor([5.9052, 12.5865, 4.4840, 10.0240], device=torch_device)
+        expected_logits_action = torch.tensor(
+            [5.9052, 12.5865, 4.4840, 10.0240], dtype=torch.float16, device=torch_device
+        )
         expected_labels_object = [1, 10, 3, 4]
-        expected_logits_object = torch.tensor([5.5316, 11.7946, 8.1672, 23.2415], device=torch_device)
+        expected_logits_object = torch.tensor(
+            [5.5316, 11.7946, 8.1672, 23.2415], dtype=torch.float16, device=torch_device
+        )
         expected_labels_location = [0, 0, 0, 1]
-        expected_logits_location = torch.tensor([5.2053, 8.9577, 10.0447, 8.1481], device=torch_device)
+        expected_logits_location = torch.tensor(
+            [5.2053, 8.9577, 10.0447, 8.1481], dtype=torch.float16, device=torch_device
+        )
 
         self.assertListEqual(predicted_ids_action.tolist(), expected_labels_action)
         self.assertListEqual(predicted_ids_object.tolist(), expected_labels_object)
@@ -708,7 +720,9 @@ class HubertModelIntegrationTest(unittest.TestCase):
         self.assertTrue(torch.allclose(predicted_logits_location, expected_logits_location, atol=3e-1))
 
     def test_inference_speaker_identification(self):
-        model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-sid").to(torch_device)
+        model = HubertForSequenceClassification.from_pretrained(
+            "superb/hubert-base-superb-sid", torch_dtype=torch.float16
+        ).to(torch_device)
         processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-sid")
         input_data = self._load_superb("si", 4)
 
@@ -716,26 +730,30 @@ class HubertModelIntegrationTest(unittest.TestCase):
         with torch.no_grad():
             for example in input_data["speech"]:
                 input = processor(example, return_tensors="pt", padding=True)
-                output = model(input.input_values.to(torch_device), attention_mask=None)
+                output = model(input.input_values.half().to(torch_device), attention_mask=None)
                 output_logits.append(output.logits[0])
         output_logits = torch.stack(output_logits)
         predicted_logits, predicted_ids = torch.max(output_logits, dim=-1)
 
         expected_labels = [5, 1, 1, 3]
         # s3prl logits for the same batch
-        expected_logits = torch.tensor([78231.5547, 123166.6094, 122785.4141, 84851.2969], device=torch_device)
+        expected_logits = torch.tensor(
+            [78231.5547, 123166.6094, 122785.4141, 84851.2969], dtype=torch.float16, device=torch_device
+        )
 
         self.assertListEqual(predicted_ids.tolist(), expected_labels)
         # TODO: lower the tolerance after merging the padding fix https://github.com/pytorch/fairseq/pull/3572
         self.assertTrue(torch.allclose(predicted_logits, expected_logits, atol=10))
 
     def test_inference_emotion_recognition(self):
-        model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-er").to(torch_device)
+        model = HubertForSequenceClassification.from_pretrained(
+            "superb/hubert-base-superb-er", torch_dtype=torch.float16
+        ).to(torch_device)
         processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-er")
         input_data = self._load_superb("er", 4)
         inputs = processor(input_data["speech"], return_tensors="pt", padding=True)
 
-        input_values = inputs.input_values.to(torch_device)
+        input_values = inputs.input_values.half().to(torch_device)
         attention_mask = inputs.attention_mask.to(torch_device)
         with torch.no_grad():
             outputs = model(input_values, attention_mask=attention_mask)
@@ -743,7 +761,7 @@ class HubertModelIntegrationTest(unittest.TestCase):
         expected_labels = [1, 1, 2, 2]
         # s3prl logits for the same batch
-        expected_logits = torch.tensor([2.8384, 2.3389, 3.8564, 4.5558], device=torch_device)
+        expected_logits = torch.tensor([2.8384, 2.3389, 3.8564, 4.5558], dtype=torch.float16, device=torch_device)
 
         self.assertListEqual(predicted_ids.tolist(), expected_labels)
         # TODO: lower the tolerance after merging the padding fix https://github.com/pytorch/fairseq/pull/3572
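A side effect visible in the diff: comparing fp16 logits against fp16 reference tensors needs a looser absolute tolerance than the fp32 version of the same check, which is why the keyword-spotting assertion moves from atol=1e-2 to atol=2e-2. A small self-contained illustration; the expected values are taken from the diff, while the predicted values are made up for the example:

    import torch

    # fp16 carries roughly 3 decimal digits of precision, so element-wise differences
    # of a few hundredths can appear even for a "matching" prediction.
    expected = torch.tensor([7.6692, 17.7795, 11.1562, 11.8232], dtype=torch.float16)
    predicted = torch.tensor([7.6719, 17.7812, 11.1719, 11.8281], dtype=torch.float16)  # hypothetical fp16 outputs

    print(torch.allclose(predicted, expected, atol=1e-2))  # False: one element is off by ~0.016
    print(torch.allclose(predicted, expected, atol=2e-2))  # True: within the loosened fp16 tolerance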