"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "aa90e0c36adc0034ece203c857d0d993c82ae65a"
Unverified commit cd19b193, authored by Hz, Ji and committed by GitHub
Browse files

make tests of pytorch_example device agnostic (#27081)

parent 6b466771
...@@ -24,11 +24,16 @@ import tempfile ...@@ -24,11 +24,16 @@ import tempfile
import unittest import unittest
from unittest import mock from unittest import mock
import torch
from accelerate.utils import write_basic_config from accelerate.utils import write_basic_config
from transformers.testing_utils import TestCasePlus, get_gpu_count, run_command, slow, torch_device from transformers.testing_utils import (
from transformers.utils import is_apex_available TestCasePlus,
backend_device_count,
is_torch_fp16_available_on_device,
run_command,
slow,
torch_device,
)
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
...@@ -54,11 +59,6 @@ def get_results(output_dir): ...@@ -54,11 +59,6 @@ def get_results(output_dir):
return results return results
def is_cuda_and_apex_available():
is_using_cuda = torch.cuda.is_available() and torch_device == "cuda"
return is_using_cuda and is_apex_available()
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
...@@ -93,7 +93,7 @@ class ExamplesTestsNoTrainer(TestCasePlus): ...@@ -93,7 +93,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
--with_tracking --with_tracking
""".split() """.split()
if is_cuda_and_apex_available(): if is_torch_fp16_available_on_device(torch_device):
testargs.append("--fp16") testargs.append("--fp16")
run_command(self._launch_args + testargs) run_command(self._launch_args + testargs)
...@@ -119,7 +119,7 @@ class ExamplesTestsNoTrainer(TestCasePlus): ...@@ -119,7 +119,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
--with_tracking --with_tracking
""".split() """.split()
if torch.cuda.device_count() > 1: if backend_device_count(torch_device) > 1:
# Skipping because there are not enough batches to train the model + would need a drop_last to work. # Skipping because there are not enough batches to train the model + would need a drop_last to work.
return return
...@@ -152,7 +152,7 @@ class ExamplesTestsNoTrainer(TestCasePlus): ...@@ -152,7 +152,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
@mock.patch.dict(os.environ, {"WANDB_MODE": "offline"}) @mock.patch.dict(os.environ, {"WANDB_MODE": "offline"})
def test_run_ner_no_trainer(self): def test_run_ner_no_trainer(self):
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
epochs = 7 if get_gpu_count() > 1 else 2 epochs = 7 if backend_device_count(torch_device) > 1 else 2
tmp_dir = self.get_auto_remove_tmp_dir() tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f""" testargs = f"""
...@@ -326,7 +326,7 @@ class ExamplesTestsNoTrainer(TestCasePlus): ...@@ -326,7 +326,7 @@ class ExamplesTestsNoTrainer(TestCasePlus):
--checkpointing_steps 1 --checkpointing_steps 1
""".split() """.split()
if is_cuda_and_apex_available(): if is_torch_fp16_available_on_device(torch_device):
testargs.append("--fp16") testargs.append("--fp16")
run_command(self._launch_args + testargs) run_command(self._launch_args + testargs)
......
...@@ -20,11 +20,15 @@ import os ...@@ -20,11 +20,15 @@ import os
import sys import sys
from unittest.mock import patch from unittest.mock import patch
import torch
from transformers import ViTMAEForPreTraining, Wav2Vec2ForPreTraining from transformers import ViTMAEForPreTraining, Wav2Vec2ForPreTraining
from transformers.testing_utils import CaptureLogger, TestCasePlus, get_gpu_count, slow, torch_device from transformers.testing_utils import (
from transformers.utils import is_apex_available CaptureLogger,
TestCasePlus,
backend_device_count,
is_torch_fp16_available_on_device,
slow,
torch_device,
)
SRC_DIRS = [ SRC_DIRS = [
...@@ -86,11 +90,6 @@ def get_results(output_dir): ...@@ -86,11 +90,6 @@ def get_results(output_dir):
return results return results
def is_cuda_and_apex_available():
is_using_cuda = torch.cuda.is_available() and torch_device == "cuda"
return is_using_cuda and is_apex_available()
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler) logger.addHandler(stream_handler)
...@@ -116,7 +115,7 @@ class ExamplesTests(TestCasePlus): ...@@ -116,7 +115,7 @@ class ExamplesTests(TestCasePlus):
--max_seq_length=128 --max_seq_length=128
""".split() """.split()
if is_cuda_and_apex_available(): if is_torch_fp16_available_on_device(torch_device):
testargs.append("--fp16") testargs.append("--fp16")
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
...@@ -141,7 +140,7 @@ class ExamplesTests(TestCasePlus): ...@@ -141,7 +140,7 @@ class ExamplesTests(TestCasePlus):
--overwrite_output_dir --overwrite_output_dir
""".split() """.split()
if torch.cuda.device_count() > 1: if backend_device_count(torch_device) > 1:
# Skipping because there are not enough batches to train the model + would need a drop_last to work. # Skipping because there are not enough batches to train the model + would need a drop_last to work.
return return
...@@ -203,7 +202,7 @@ class ExamplesTests(TestCasePlus): ...@@ -203,7 +202,7 @@ class ExamplesTests(TestCasePlus):
def test_run_ner(self): def test_run_ner(self):
# with so little data distributed training needs more epochs to get the score on par with 0/1 gpu # with so little data distributed training needs more epochs to get the score on par with 0/1 gpu
epochs = 7 if get_gpu_count() > 1 else 2 epochs = 7 if backend_device_count(torch_device) > 1 else 2
tmp_dir = self.get_auto_remove_tmp_dir() tmp_dir = self.get_auto_remove_tmp_dir()
testargs = f""" testargs = f"""
...@@ -312,7 +311,7 @@ class ExamplesTests(TestCasePlus): ...@@ -312,7 +311,7 @@ class ExamplesTests(TestCasePlus):
def test_generation(self): def test_generation(self):
testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"] testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
if is_cuda_and_apex_available(): if is_torch_fp16_available_on_device(torch_device):
testargs.append("--fp16") testargs.append("--fp16")
model_type, model_name = ( model_type, model_name = (
...@@ -401,7 +400,7 @@ class ExamplesTests(TestCasePlus): ...@@ -401,7 +400,7 @@ class ExamplesTests(TestCasePlus):
--seed 42 --seed 42
""".split() """.split()
if is_cuda_and_apex_available(): if is_torch_fp16_available_on_device(torch_device):
testargs.append("--fp16") testargs.append("--fp16")
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
...@@ -431,7 +430,7 @@ class ExamplesTests(TestCasePlus): ...@@ -431,7 +430,7 @@ class ExamplesTests(TestCasePlus):
--seed 42 --seed 42
""".split() """.split()
if is_cuda_and_apex_available(): if is_torch_fp16_available_on_device(torch_device):
testargs.append("--fp16") testargs.append("--fp16")
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
...@@ -462,7 +461,7 @@ class ExamplesTests(TestCasePlus): ...@@ -462,7 +461,7 @@ class ExamplesTests(TestCasePlus):
--seed 42 --seed 42
""".split() """.split()
if is_cuda_and_apex_available(): if is_torch_fp16_available_on_device(torch_device):
testargs.append("--fp16") testargs.append("--fp16")
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
...@@ -493,7 +492,7 @@ class ExamplesTests(TestCasePlus): ...@@ -493,7 +492,7 @@ class ExamplesTests(TestCasePlus):
--seed 42 --seed 42
""".split() """.split()
if is_cuda_and_apex_available(): if is_torch_fp16_available_on_device(torch_device):
testargs.append("--fp16") testargs.append("--fp16")
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
...@@ -525,7 +524,7 @@ class ExamplesTests(TestCasePlus): ...@@ -525,7 +524,7 @@ class ExamplesTests(TestCasePlus):
--seed 42 --seed 42
""".split() """.split()
if is_cuda_and_apex_available(): if is_torch_fp16_available_on_device(torch_device):
testargs.append("--fp16") testargs.append("--fp16")
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
...@@ -551,7 +550,7 @@ class ExamplesTests(TestCasePlus): ...@@ -551,7 +550,7 @@ class ExamplesTests(TestCasePlus):
--seed 42 --seed 42
""".split() """.split()
if is_cuda_and_apex_available(): if is_torch_fp16_available_on_device(torch_device):
testargs.append("--fp16") testargs.append("--fp16")
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
...@@ -579,7 +578,7 @@ class ExamplesTests(TestCasePlus): ...@@ -579,7 +578,7 @@ class ExamplesTests(TestCasePlus):
--seed 42 --seed 42
""".split() """.split()
if is_cuda_and_apex_available(): if is_torch_fp16_available_on_device(torch_device):
testargs.append("--fp16") testargs.append("--fp16")
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
...@@ -604,7 +603,7 @@ class ExamplesTests(TestCasePlus): ...@@ -604,7 +603,7 @@ class ExamplesTests(TestCasePlus):
--seed 32 --seed 32
""".split() """.split()
if is_cuda_and_apex_available(): if is_torch_fp16_available_on_device(torch_device):
testargs.append("--fp16") testargs.append("--fp16")
with patch.object(sys, "argv", testargs): with patch.object(sys, "argv", testargs):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment