Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
56b03c96
"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "c8b6052ff681e3ca8dab168dfd524b9fbbceb5bd"
Unverified
Commit
56b03c96
authored
Feb 13, 2023
by
Joao Gante
Committed by
GitHub
Feb 13, 2023
Browse files
Fix TF CTC tests (#21606)
parent
cbecf121
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
9 deletions
+21
-9
tests/models/hubert/test_modeling_tf_hubert.py
tests/models/hubert/test_modeling_tf_hubert.py
+19
-7
tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
+2
-2
No files found.
tests/models/hubert/test_modeling_tf_hubert.py
View file @
56b03c96
...
@@ -321,6 +321,20 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -321,6 +321,20 @@ class TFHubertModelTest(TFModelTesterMixin, unittest.TestCase):
model
=
TFHubertModel
.
from_pretrained
(
"facebook/hubert-base-ls960"
)
model
=
TFHubertModel
.
from_pretrained
(
"facebook/hubert-base-ls960"
)
self
.
assertIsNotNone
(
model
)
self
.
assertIsNotNone
(
model
)
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def
test_dataset_conversion
(
self
):
default_batch_size
=
self
.
model_tester
.
batch_size
self
.
model_tester
.
batch_size
=
2
super
().
test_dataset_conversion
()
self
.
model_tester
.
batch_size
=
default_batch_size
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def
test_keras_fit
(
self
):
default_batch_size
=
self
.
model_tester
.
batch_size
self
.
model_tester
.
batch_size
=
2
super
().
test_keras_fit
()
self
.
model_tester
.
batch_size
=
default_batch_size
@
require_tf
@
require_tf
class
TFHubertRobustModelTest
(
TFModelTesterMixin
,
unittest
.
TestCase
):
class
TFHubertRobustModelTest
(
TFModelTesterMixin
,
unittest
.
TestCase
):
...
@@ -431,20 +445,18 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -431,20 +445,18 @@ class TFHubertRobustModelTest(TFModelTesterMixin, unittest.TestCase):
def
test_model_common_attributes
(
self
):
def
test_model_common_attributes
(
self
):
pass
pass
@
slow
def
test_model_from_pretrained
(
self
):
model
=
TFHubertModel
.
from_pretrained
(
"facebook/hubert-large-ls960-ft"
)
self
.
assertIsNotNone
(
model
)
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
# TODO: fix me
@
unittest
.
skip
(
reason
=
"Crashing on CI, temporarily skipped"
)
def
test_dataset_conversion
(
self
):
def
test_dataset_conversion
(
self
):
default_batch_size
=
self
.
model_tester
.
batch_size
default_batch_size
=
self
.
model_tester
.
batch_size
self
.
model_tester
.
batch_size
=
2
self
.
model_tester
.
batch_size
=
2
super
().
test_dataset_conversion
()
super
().
test_dataset_conversion
()
self
.
model_tester
.
batch_size
=
default_batch_size
self
.
model_tester
.
batch_size
=
default_batch_size
@
slow
def
test_model_from_pretrained
(
self
):
model
=
TFHubertModel
.
from_pretrained
(
"facebook/hubert-large-ls960-ft"
)
self
.
assertIsNotNone
(
model
)
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
# We override here as passing a full batch of 13 samples results in OOM errors for CTC
def
test_keras_fit
(
self
):
def
test_keras_fit
(
self
):
default_batch_size
=
self
.
model_tester
.
batch_size
default_batch_size
=
self
.
model_tester
.
batch_size
...
...
tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
View file @
56b03c96
...
@@ -396,7 +396,7 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -396,7 +396,7 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
def
test_keras_fit
(
self
):
def
test_keras_fit
(
self
):
default_batch_size
=
self
.
model_tester
.
batch_size
default_batch_size
=
self
.
model_tester
.
batch_size
self
.
model_tester
.
batch_size
=
2
self
.
model_tester
.
batch_size
=
2
super
().
test_
dataset_conversion
()
super
().
test_
keras_fit
()
self
.
model_tester
.
batch_size
=
default_batch_size
self
.
model_tester
.
batch_size
=
default_batch_size
...
@@ -527,7 +527,7 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -527,7 +527,7 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
def
test_keras_fit
(
self
):
def
test_keras_fit
(
self
):
default_batch_size
=
self
.
model_tester
.
batch_size
default_batch_size
=
self
.
model_tester
.
batch_size
self
.
model_tester
.
batch_size
=
2
self
.
model_tester
.
batch_size
=
2
super
().
test_
dataset_conversion
()
super
().
test_
keras_fit
()
self
.
model_tester
.
batch_size
=
default_batch_size
self
.
model_tester
.
batch_size
=
default_batch_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment