Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
cd19b193
Unverified
Commit
cd19b193
authored
Oct 30, 2023
by
Hz, Ji
Committed by
GitHub
Oct 30, 2023
Browse files
make tests of pytorch_example device agnostic (#27081)
parent
6b466771
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
33 deletions
+32
-33
examples/pytorch/test_accelerate_examples.py
examples/pytorch/test_accelerate_examples.py
+12
-12
examples/pytorch/test_pytorch_examples.py
examples/pytorch/test_pytorch_examples.py
+20
-21
No files found.
examples/pytorch/test_accelerate_examples.py
View file @
cd19b193
...
@@ -24,11 +24,16 @@ import tempfile
...
@@ -24,11 +24,16 @@ import tempfile
import
unittest
import
unittest
from
unittest
import
mock
from
unittest
import
mock
import
torch
from
accelerate.utils
import
write_basic_config
from
accelerate.utils
import
write_basic_config
from
transformers.testing_utils
import
TestCasePlus
,
get_gpu_count
,
run_command
,
slow
,
torch_device
from
transformers.testing_utils
import
(
from
transformers.utils
import
is_apex_available
TestCasePlus
,
backend_device_count
,
is_torch_fp16_available_on_device
,
run_command
,
slow
,
torch_device
,
)
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
...
@@ -54,11 +59,6 @@ def get_results(output_dir):
...
@@ -54,11 +59,6 @@ def get_results(output_dir):
return
results
return
results
def
is_cuda_and_apex_available
():
is_using_cuda
=
torch
.
cuda
.
is_available
()
and
torch_device
==
"cuda"
return
is_using_cuda
and
is_apex_available
()
stream_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
stream_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
logger
.
addHandler
(
stream_handler
)
logger
.
addHandler
(
stream_handler
)
...
@@ -93,7 +93,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
...
@@ -93,7 +93,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
--with_tracking
--with_tracking
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
run_command
(
self
.
_launch_args
+
testargs
)
run_command
(
self
.
_launch_args
+
testargs
)
...
@@ -119,7 +119,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
...
@@ -119,7 +119,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
--with_tracking
--with_tracking
"""
.
split
()
"""
.
split
()
if
torch
.
cuda
.
device_count
()
>
1
:
if
backend_
device_count
(
torch_device
)
>
1
:
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
return
return
...
@@ -152,7 +152,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
...
@@ -152,7 +152,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
@
mock
.
patch
.
dict
(
os
.
environ
,
{
"WANDB_MODE"
:
"offline"
})
@
mock
.
patch
.
dict
(
os
.
environ
,
{
"WANDB_MODE"
:
"offline"
})
def
test_run_ner_no_trainer
(
self
):
def
test_run_ner_no_trainer
(
self
):
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
epochs
=
7
if
get_gpu_count
(
)
>
1
else
2
epochs
=
7
if
backend_device_count
(
torch_device
)
>
1
else
2
tmp_dir
=
self
.
get_auto_remove_tmp_dir
()
tmp_dir
=
self
.
get_auto_remove_tmp_dir
()
testargs
=
f
"""
testargs
=
f
"""
...
@@ -326,7 +326,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
...
@@ -326,7 +326,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
--checkpointing_steps 1
--checkpointing_steps 1
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
run_command
(
self
.
_launch_args
+
testargs
)
run_command
(
self
.
_launch_args
+
testargs
)
...
...
examples/pytorch/test_pytorch_examples.py
View file @
cd19b193
...
@@ -20,11 +20,15 @@ import os
...
@@ -20,11 +20,15 @@ import os
import
sys
import
sys
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
torch
from
transformers
import
ViTMAEForPreTraining
,
Wav2Vec2ForPreTraining
from
transformers
import
ViTMAEForPreTraining
,
Wav2Vec2ForPreTraining
from
transformers.testing_utils
import
CaptureLogger
,
TestCasePlus
,
get_gpu_count
,
slow
,
torch_device
from
transformers.testing_utils
import
(
from
transformers.utils
import
is_apex_available
CaptureLogger
,
TestCasePlus
,
backend_device_count
,
is_torch_fp16_available_on_device
,
slow
,
torch_device
,
)
SRC_DIRS
=
[
SRC_DIRS
=
[
...
@@ -86,11 +90,6 @@ def get_results(output_dir):
...
@@ -86,11 +90,6 @@ def get_results(output_dir):
return
results
return
results
def
is_cuda_and_apex_available
():
is_using_cuda
=
torch
.
cuda
.
is_available
()
and
torch_device
==
"cuda"
return
is_using_cuda
and
is_apex_available
()
stream_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
stream_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
logger
.
addHandler
(
stream_handler
)
logger
.
addHandler
(
stream_handler
)
...
@@ -116,7 +115,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -116,7 +115,7 @@ class ExamplesTests(TestCasePlus):
--max_seq_length=128
--max_seq_length=128
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
@@ -141,7 +140,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -141,7 +140,7 @@ class ExamplesTests(TestCasePlus):
--overwrite_output_dir
--overwrite_output_dir
"""
.
split
()
"""
.
split
()
if
torch
.
cuda
.
device_count
()
>
1
:
if
backend_
device_count
(
torch_device
)
>
1
:
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
return
return
...
@@ -203,7 +202,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -203,7 +202,7 @@ class ExamplesTests(TestCasePlus):
def
test_run_ner
(
self
):
def
test_run_ner
(
self
):
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
epochs
=
7
if
get_gpu_count
(
)
>
1
else
2
epochs
=
7
if
backend_device_count
(
torch_device
)
>
1
else
2
tmp_dir
=
self
.
get_auto_remove_tmp_dir
()
tmp_dir
=
self
.
get_auto_remove_tmp_dir
()
testargs
=
f
"""
testargs
=
f
"""
...
@@ -312,7 +311,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -312,7 +311,7 @@ class ExamplesTests(TestCasePlus):
def
test_generation
(
self
):
def
test_generation
(
self
):
testargs
=
[
"run_generation.py"
,
"--prompt=Hello"
,
"--length=10"
,
"--seed=42"
]
testargs
=
[
"run_generation.py"
,
"--prompt=Hello"
,
"--length=10"
,
"--seed=42"
]
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
model_type
,
model_name
=
(
model_type
,
model_name
=
(
...
@@ -401,7 +400,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -401,7 +400,7 @@ class ExamplesTests(TestCasePlus):
--seed 42
--seed 42
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
@@ -431,7 +430,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -431,7 +430,7 @@ class ExamplesTests(TestCasePlus):
--seed 42
--seed 42
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
@@ -462,7 +461,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -462,7 +461,7 @@ class ExamplesTests(TestCasePlus):
--seed 42
--seed 42
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
@@ -493,7 +492,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -493,7 +492,7 @@ class ExamplesTests(TestCasePlus):
--seed 42
--seed 42
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
@@ -525,7 +524,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -525,7 +524,7 @@ class ExamplesTests(TestCasePlus):
--seed 42
--seed 42
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
@@ -551,7 +550,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -551,7 +550,7 @@ class ExamplesTests(TestCasePlus):
--seed 42
--seed 42
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
@@ -579,7 +578,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -579,7 +578,7 @@ class ExamplesTests(TestCasePlus):
--seed 42
--seed 42
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
@@ -604,7 +603,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -604,7 +603,7 @@ class ExamplesTests(TestCasePlus):
--seed 32
--seed 32
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment