Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
56b03c96
Unverified
Commit
56b03c96
authored
Feb 13, 2023
by
Joao Gante
Committed by
GitHub
Feb 13, 2023
Browse files
Fix TF CTC tests (#21606)
parent
cbecf121
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
9 deletions
+21
-9
tests/models/hubert/test_modeling_tf_hubert.py
tests/models/hubert/test_modeling_tf_hubert.py
+19
-7
tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
+2
-2
No files found.
tests/models/hubert/test_modeling_tf_hubert.py
View file @
56b03c96
...
@@ -321,6 +321,20 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -321,6 +321,20 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase):
model
=
TFHubertModel
.
from_pretrained
(
"facebook/hubert-base-ls960"
)
model
=
TFHubertModel
.
from_pretrained
(
"facebook/hubert-base-ls960"
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def
test_dataset_conversion
(
self
):
default_batch_size
=
self
.
model_tester
.
batch_size
self
.
model_tester
.
batch_size
=
2
super
().
test_dataset_conversion
()
self
.
model_tester
.
batch_size
=
default_batch_size
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def
test_keras_fit
(
self
):
default_batch_size
=
self
.
model_tester
.
batch_size
self
.
model_tester
.
batch_size
=
2
super
().
test_keras_fit
()
self
.
model_tester
.
batch_size
=
default_batch_size
@
require_tf
@
require_tf
class
TFHubertRobustModelTest
(
TFModelTesterMixin
,
unittest
.
TestCase
):
class
TFHubertRobustModelTest
(
TFModelTesterMixin
,
unittest
.
TestCase
):
...
@@ -431,20 +445,18 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -431,20 +445,18 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
def
test_model_common_attributes
(
self
):
def
test_model_common_attributes
(
self
):
pass
pass
@
slow
def
test_model_from_pretrained
(
self
):
model
=
TFHubertModel
.
from_pretrained
(
"facebook/hubert-large-ls960-ft"
)
self
.
assertIsNotNone
(
model
)
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
# TODO: fix me
@
unittest
.
skip
(
reason
=
"Crashing on CI, temporarily skipped"
)
def
test_dataset_conversion
(
self
):
def
test_dataset_conversion
(
self
):
default_batch_size
=
self
.
model_tester
.
batch_size
default_batch_size
=
self
.
model_tester
.
batch_size
self
.
model_tester
.
batch_size
=
2
self
.
model_tester
.
batch_size
=
2
super
().
test_dataset_conversion
()
super
().
test_dataset_conversion
()
self
.
model_tester
.
batch_size
=
default_batch_size
self
.
model_tester
.
batch_size
=
default_batch_size
@
slow
def
test_model_from_pretrained
(
self
):
model
=
TFHubertModel
.
from_pretrained
(
"facebook/hubert-large-ls960-ft"
)
self
.
assertIsNotNone
(
model
)
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def
test_keras_fit
(
self
):
def
test_keras_fit
(
self
):
default_batch_size
=
self
.
model_tester
.
batch_size
default_batch_size
=
self
.
model_tester
.
batch_size
...
...
tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
View file @
56b03c96
...
@@ -396,7 +396,7 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -396,7 +396,7 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
def
test_keras_fit
(
self
):
def
test_keras_fit
(
self
):
default_batch_size
=
self
.
model_tester
.
batch_size
default_batch_size
=
self
.
model_tester
.
batch_size
self
.
model_tester
.
batch_size
=
2
self
.
model_tester
.
batch_size
=
2
super
().
test_
dataset_conversion
()
super
().
test_
keras_fit
()
self
.
model_tester
.
batch_size
=
default_batch_size
self
.
model_tester
.
batch_size
=
default_batch_size
...
@@ -527,7 +527,7 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -527,7 +527,7 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
def
test_keras_fit
(
self
):
def
test_keras_fit
(
self
):
default_batch_size
=
self
.
model_tester
.
batch_size
default_batch_size
=
self
.
model_tester
.
batch_size
self
.
model_tester
.
batch_size
=
2
self
.
model_tester
.
batch_size
=
2
super
().
test_
dataset_conversion
()
super
().
test_
keras_fit
()
self
.
model_tester
.
batch_size
=
default_batch_size
self
.
model_tester
.
batch_size
=
default_batch_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment