Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e26c6f03
Unverified
Commit
e26c6f03
authored
Jun 12, 2023
by
Yih-Dar
Committed by
GitHub
Jun 12, 2023
Browse files
Fix `Wav2Vec2` CI OOM (#24190)
fix Co-authored-by:
ydshieh
<
ydshieh@users.noreply.github.com
>
parent
8f093fb7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
0 deletions
+13
-0
tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
+6
-0
tests/models/wav2vec2/test_modeling_wav2vec2.py
tests/models/wav2vec2/test_modeling_wav2vec2.py
+7
-0
No files found.
tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
View file @
e26c6f03
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
from
__future__
import
annotations
from
__future__
import
annotations
import
copy
import
copy
import
gc
import
glob
import
glob
import
inspect
import
inspect
import
math
import
math
...
@@ -709,6 +710,11 @@ class TFWav2Vec2UtilsTest(unittest.TestCase):
...
@@ -709,6 +710,11 @@ class TFWav2Vec2UtilsTest(unittest.TestCase):
@
require_tf
@
require_tf
@
slow
@
slow
class
TFWav2Vec2ModelIntegrationTest
(
unittest
.
TestCase
):
class
TFWav2Vec2ModelIntegrationTest
(
unittest
.
TestCase
):
def
tearDown
(
self
):
super
().
tearDown
()
# clean-up as much as possible GPU memory occupied by PyTorch
gc
.
collect
()
def
_load_datasamples
(
self
,
num_samples
):
def
_load_datasamples
(
self
,
num_samples
):
ds
=
load_dataset
(
"hf-internal-testing/librispeech_asr_dummy"
,
"clean"
,
split
=
"validation"
)
ds
=
load_dataset
(
"hf-internal-testing/librispeech_asr_dummy"
,
"clean"
,
split
=
"validation"
)
# automatic decoding with librispeech
# automatic decoding with librispeech
...
...
tests/models/wav2vec2/test_modeling_wav2vec2.py
View file @
e26c6f03
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
# limitations under the License.
# limitations under the License.
""" Testing suite for the PyTorch Wav2Vec2 model. """
""" Testing suite for the PyTorch Wav2Vec2 model. """
import
gc
import
math
import
math
import
multiprocessing
import
multiprocessing
import
os
import
os
...
@@ -1374,6 +1375,12 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
...
@@ -1374,6 +1375,12 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
@
require_soundfile
@
require_soundfile
@
slow
@
slow
class
Wav2Vec2ModelIntegrationTest
(
unittest
.
TestCase
):
class
Wav2Vec2ModelIntegrationTest
(
unittest
.
TestCase
):
def
tearDown
(
self
):
super
().
tearDown
()
# clean-up as much as possible GPU memory occupied by PyTorch
gc
.
collect
()
torch
.
cuda
.
empty_cache
()
def
_load_datasamples
(
self
,
num_samples
):
def
_load_datasamples
(
self
,
num_samples
):
ds
=
load_dataset
(
"hf-internal-testing/librispeech_asr_dummy"
,
"clean"
,
split
=
"validation"
)
ds
=
load_dataset
(
"hf-internal-testing/librispeech_asr_dummy"
,
"clean"
,
split
=
"validation"
)
# automatic decoding with librispeech
# automatic decoding with librispeech
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment