Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
cd19b193
"src/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "b7729f0a56ee703f2ae3dda414b8e58f6dfb749d"
Unverified
Commit
cd19b193
authored
Oct 30, 2023
by
Hz, Ji
Committed by
GitHub
Oct 30, 2023
Browse files
make tests of pytorch_example device agnostic (#27081)
parent
6b466771
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
32 additions
and
33 deletions
+32
-33
examples/pytorch/test_accelerate_examples.py
examples/pytorch/test_accelerate_examples.py
+12
-12
examples/pytorch/test_pytorch_examples.py
examples/pytorch/test_pytorch_examples.py
+20
-21
No files found.
examples/pytorch/test_accelerate_examples.py
View file @
cd19b193
...
@@ -24,11 +24,16 @@ import tempfile
...
@@ -24,11 +24,16 @@ import tempfile
import
unittest
import
unittest
from
unittest
import
mock
from
unittest
import
mock
import
torch
from
accelerate.utils
import
write_basic_config
from
accelerate.utils
import
write_basic_config
from
transformers.testing_utils
import
TestCasePlus
,
get_gpu_count
,
run_command
,
slow
,
torch_device
from
transformers.testing_utils
import
(
from
transformers.utils
import
is_apex_available
TestCasePlus
,
backend_device_count
,
is_torch_fp16_available_on_device
,
run_command
,
slow
,
torch_device
,
)
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
logging
.
basicConfig
(
level
=
logging
.
DEBUG
)
...
@@ -54,11 +59,6 @@ def get_results(output_dir):
...
@@ -54,11 +59,6 @@ def get_results(output_dir):
return
results
return
results
def
is_cuda_and_apex_available
():
is_using_cuda
=
torch
.
cuda
.
is_available
()
and
torch_device
==
"cuda"
return
is_using_cuda
and
is_apex_available
()
stream_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
stream_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
logger
.
addHandler
(
stream_handler
)
logger
.
addHandler
(
stream_handler
)
...
@@ -93,7 +93,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
...
@@ -93,7 +93,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
--with_tracking
--with_tracking
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
run_command
(
self
.
_launch_args
+
testargs
)
run_command
(
self
.
_launch_args
+
testargs
)
...
@@ -119,7 +119,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
...
@@ -119,7 +119,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
--with_tracking
--with_tracking
"""
.
split
()
"""
.
split
()
if
torch
.
cuda
.
device_count
()
>
1
:
if
backend_
device_count
(
torch_device
)
>
1
:
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
return
return
...
@@ -152,7 +152,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
...
@@ -152,7 +152,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
@
mock
.
patch
.
dict
(
os
.
environ
,
{
"WANDB_MODE"
:
"offline"
})
@
mock
.
patch
.
dict
(
os
.
environ
,
{
"WANDB_MODE"
:
"offline"
})
def
test_run_ner_no_trainer
(
self
):
def
test_run_ner_no_trainer
(
self
):
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
epochs
=
7
if
get_gpu_count
(
)
>
1
else
2
epochs
=
7
if
backend_device_count
(
torch_device
)
>
1
else
2
tmp_dir
=
self
.
get_auto_remove_tmp_dir
()
tmp_dir
=
self
.
get_auto_remove_tmp_dir
()
testargs
=
f
"""
testargs
=
f
"""
...
@@ -326,7 +326,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
...
@@ -326,7 +326,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
--checkpointing_steps 1
--checkpointing_steps 1
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
run_command
(
self
.
_launch_args
+
testargs
)
run_command
(
self
.
_launch_args
+
testargs
)
...
...
examples/pytorch/test_pytorch_examples.py
View file @
cd19b193
...
@@ -20,11 +20,15 @@ import os
...
@@ -20,11 +20,15 @@ import os
import
sys
import
sys
from
unittest.mock
import
patch
from
unittest.mock
import
patch
import
torch
from
transformers
import
ViTMAEForPreTraining
,
Wav2Vec2ForPreTraining
from
transformers
import
ViTMAEForPreTraining
,
Wav2Vec2ForPreTraining
from
transformers.testing_utils
import
CaptureLogger
,
TestCasePlus
,
get_gpu_count
,
slow
,
torch_device
from
transformers.testing_utils
import
(
from
transformers.utils
import
is_apex_available
CaptureLogger
,
TestCasePlus
,
backend_device_count
,
is_torch_fp16_available_on_device
,
slow
,
torch_device
,
)
SRC_DIRS
=
[
SRC_DIRS
=
[
...
@@ -86,11 +90,6 @@ def get_results(output_dir):
...
@@ -86,11 +90,6 @@ def get_results(output_dir):
return
results
return
results
def
is_cuda_and_apex_available
():
is_using_cuda
=
torch
.
cuda
.
is_available
()
and
torch_device
==
"cuda"
return
is_using_cuda
and
is_apex_available
()
stream_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
stream_handler
=
logging
.
StreamHandler
(
sys
.
stdout
)
logger
.
addHandler
(
stream_handler
)
logger
.
addHandler
(
stream_handler
)
...
@@ -116,7 +115,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -116,7 +115,7 @@ class ExamplesTests(TestCasePlus):
--max_seq_length=128
--max_seq_length=128
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
@@ -141,7 +140,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -141,7 +140,7 @@ class ExamplesTests(TestCasePlus):
--overwrite_output_dir
--overwrite_output_dir
"""
.
split
()
"""
.
split
()
if
torch
.
cuda
.
device_count
()
>
1
:
if
backend_
device_count
(
torch_device
)
>
1
:
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
# Skipping because there are not enough batches to train the model + would need a drop_last to work.
return
return
...
@@ -203,7 +202,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -203,7 +202,7 @@ class ExamplesTests(TestCasePlus):
def
test_run_ner
(
self
):
def
test_run_ner
(
self
):
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
epochs
=
7
if
get_gpu_count
(
)
>
1
else
2
epochs
=
7
if
backend_device_count
(
torch_device
)
>
1
else
2
tmp_dir
=
self
.
get_auto_remove_tmp_dir
()
tmp_dir
=
self
.
get_auto_remove_tmp_dir
()
testargs
=
f
"""
testargs
=
f
"""
...
@@ -312,7 +311,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -312,7 +311,7 @@ class ExamplesTests(TestCasePlus):
def
test_generation
(
self
):
def
test_generation
(
self
):
testargs
=
[
"run_generation.py"
,
"--prompt=Hello"
,
"--length=10"
,
"--seed=42"
]
testargs
=
[
"run_generation.py"
,
"--prompt=Hello"
,
"--length=10"
,
"--seed=42"
]
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
model_type
,
model_name
=
(
model_type
,
model_name
=
(
...
@@ -401,7 +400,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -401,7 +400,7 @@ class ExamplesTests(TestCasePlus):
--seed 42
--seed 42
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
@@ -431,7 +430,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -431,7 +430,7 @@ class ExamplesTests(TestCasePlus):
--seed 42
--seed 42
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
@@ -462,7 +461,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -462,7 +461,7 @@ class ExamplesTests(TestCasePlus):
--seed 42
--seed 42
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
@@ -493,7 +492,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -493,7 +492,7 @@ class ExamplesTests(TestCasePlus):
--seed 42
--seed 42
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
@@ -525,7 +524,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -525,7 +524,7 @@ class ExamplesTests(TestCasePlus):
--seed 42
--seed 42
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
@@ -551,7 +550,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -551,7 +550,7 @@ class ExamplesTests(TestCasePlus):
--seed 42
--seed 42
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
@@ -579,7 +578,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -579,7 +578,7 @@ class ExamplesTests(TestCasePlus):
--seed 42
--seed 42
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
@@ -604,7 +603,7 @@ class ExamplesTests(TestCasePlus):
...
@@ -604,7 +603,7 @@ class ExamplesTests(TestCasePlus):
--seed 32
--seed 32
"""
.
split
()
"""
.
split
()
if
is_
cuda_and_apex_available
(
):
if
is_
torch_fp16_available_on_device
(
torch_device
):
testargs
.
append
(
"--fp16"
)
testargs
.
append
(
"--fp16"
)
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
with
patch
.
object
(
sys
,
"argv"
,
testargs
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment