Unverified Commit 1791ef8d authored by Alex McKinney's avatar Alex McKinney Committed by GitHub
Browse files

Adds `TRANSFORMERS_TEST_DEVICE` (#25506)

* Adds `TRANSFORMERS_TEST_DEVICE`
Mirrors the same API in the diffusers library. Useful in transformers
too.

* replace backend checking with trying `torch.device`

* Adds better error message for unknown test devices

* `make style`

* adds documentation showing `TRANSFORMERS_TEST_DEVICE` usage.
parent e7e9261a
...@@ -511,6 +511,16 @@ from transformers.testing_utils import get_gpu_count ...@@ -511,6 +511,16 @@ from transformers.testing_utils import get_gpu_count
n_gpu = get_gpu_count() # works with torch and tf n_gpu = get_gpu_count() # works with torch and tf
``` ```
### Testing with a specific PyTorch backend
To run the test suite on a specific torch backend add `TRANSFORMERS_TEST_DEVICE="$device"` where `$device` is the target backend. For example, to test on CPU only:
```bash
TRANSFORMERS_TEST_DEVICE="cpu" pytest tests/test_logging.py
```
This variable is useful for testing custom or less common PyTorch backends such as `mps`. It can also be used to achieve the same effect as `CUDA_VISIBLE_DEVICES` by targeting specific GPUs or testing in CPU-only mode.
### Distributed training ### Distributed training
`pytest` can't deal with distributed training directly. If this is attempted - the sub-processes don't do the right `pytest` can't deal with distributed training directly. If this is attempted - the sub-processes don't do the right
......
...@@ -614,7 +614,16 @@ if is_torch_available(): ...@@ -614,7 +614,16 @@ if is_torch_available():
# Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode # Set env var CUDA_VISIBLE_DEVICES="" to force cpu-mode
import torch import torch
if torch.cuda.is_available(): if "TRANSFORMERS_TEST_DEVICE" in os.environ:
torch_device = os.environ["TRANSFORMERS_TEST_DEVICE"]
try:
# try creating device to see if provided device is valid
_ = torch.device(torch_device)
except RuntimeError as e:
raise RuntimeError(
f"Unknown testing device specified by environment variable `TRANSFORMERS_TEST_DEVICE`: {torch_device}"
) from e
elif torch.cuda.is_available():
torch_device = "cuda" torch_device = "cuda"
elif _run_third_party_device_tests and is_torch_npu_available(): elif _run_third_party_device_tests and is_torch_npu_available():
torch_device = "npu" torch_device = "npu"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment