Unverified Commit 66e6a021 authored by Dhruv Nair, committed by GitHub

[CI] Remove unittest dependency from `testing_utils.py` (#12621)



* update

* Update tests/testing_utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Update tests/testing_utils.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Apply style fixes

---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 5a47442f
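
The entire diff applies one mechanical rewrite: `unittest.skipUnless(condition, message)` becomes `pytest.mark.skipif(not condition, reason=message)`. The condition is negated because `skipUnless` skips a test when its condition is false, while `skipif` skips when its condition is true. One caveat: pytest marks only take effect when tests are collected and run by pytest, whereas `unittest` skip decorators also work under a plain `unittest` runner; this matters only if the tests are ever run outside pytest. A minimal sketch of the equivalence, modeled on the `slow` decorator from the diff below (the hard-coded `_run_slow_tests` flag is a stand-in for the module-level flag parsed from the RUN_SLOW environment variable):

    import pytest

    _run_slow_tests = False  # stand-in; the real flag is derived from the RUN_SLOW environment variable

    def slow(test_case):
        # Before: unittest.skipUnless skips when the condition is False.
        #     return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
        # After: pytest.mark.skipif skips when the condition is True, hence the negation.
        return pytest.mark.skipif(not _run_slow_tests, reason="test is slow")(test_case)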
@@ -13,7 +13,6 @@ import struct
 import sys
 import tempfile
 import time
-import unittest
 import urllib.parse
 from collections import UserDict
 from contextlib import contextmanager
@@ -24,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tupl
 import numpy as np
 import PIL.Image
 import PIL.ImageOps
+import pytest
 import requests
 from numpy.linalg import norm
 from packaging import version
@@ -267,7 +267,7 @@ def slow(test_case):
     Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
     """
-    return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case)
+    return pytest.mark.skipif(not _run_slow_tests, reason="test is slow")(test_case)


 def nightly(test_case):
@@ -277,7 +277,7 @@ def nightly(test_case):
     Slow tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them.
     """
-    return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case)
+    return pytest.mark.skipif(not _run_nightly_tests, reason="test is nightly")(test_case)


 def is_torch_compile(test_case):
@@ -287,23 +287,23 @@ def is_torch_compile(test_case):
     Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
     """
-    return unittest.skipUnless(_run_compile_tests, "test is torch compile")(test_case)
+    return pytest.mark.skipif(not _run_compile_tests, reason="test is torch compile")(test_case)


 def require_torch(test_case):
     """
     Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
     """
-    return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)
+    return pytest.mark.skipif(not is_torch_available(), reason="test requires PyTorch")(test_case)


 def require_torch_2(test_case):
     """
     Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed.
     """
-    return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")(
-        test_case
-    )
+    return pytest.mark.skipif(
+        not (is_torch_available() and is_torch_version(">=", "2.0.0")), reason="test requires PyTorch 2"
+    )(test_case)


 def require_torch_version_greater_equal(torch_version):
@@ -311,8 +311,9 @@ def require_torch_version_greater_equal(torch_version):
     def decorator(test_case):
         correct_torch_version = is_torch_available() and is_torch_version(">=", torch_version)
-        return unittest.skipUnless(
-            correct_torch_version, f"test requires torch with the version greater than or equal to {torch_version}"
+        return pytest.mark.skipif(
+            not correct_torch_version,
+            reason=f"test requires torch with the version greater than or equal to {torch_version}",
         )(test_case)

     return decorator
@@ -323,8 +324,8 @@ def require_torch_version_greater(torch_version):
     def decorator(test_case):
         correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
-        return unittest.skipUnless(
-            correct_torch_version, f"test requires torch with the version greater than {torch_version}"
+        return pytest.mark.skipif(
+            not correct_torch_version, reason=f"test requires torch with the version greater than {torch_version}"
         )(test_case)

     return decorator
@@ -332,19 +333,18 @@ def require_torch_version_greater(torch_version):
 def require_torch_gpu(test_case):
     """Decorator marking a test that requires CUDA and PyTorch."""
-    return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")(
-        test_case
-    )
+    return pytest.mark.skipif(torch_device != "cuda", reason="test requires PyTorch+CUDA")(test_case)


 def require_torch_cuda_compatibility(expected_compute_capability):
     def decorator(test_case):
         if torch.cuda.is_available():
             current_compute_capability = get_torch_cuda_device_capability()
-            return unittest.skipUnless(
-                float(current_compute_capability) == float(expected_compute_capability),
-                "Test not supported for this compute capability.",
-            )
+            return pytest.mark.skipif(
+                float(current_compute_capability) != float(expected_compute_capability),
+                reason="Test not supported for this compute capability.",
+            )(test_case)
         return test_case

     return decorator
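
(Note that the `require_torch_cuda_compatibility` hunk above is more than a mechanical rewrite: the old code returned the `skipUnless` decorator itself without ever applying it to `test_case`, so the capability check appears to have been a latent bug; the new version applies `(test_case)`.)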
@@ -352,9 +352,7 @@ def require_torch_cuda_compatibility(expected_compute_capability):
 # These decorators are for accelerator-specific behaviours that are not GPU-specific
 def require_torch_accelerator(test_case):
     """Decorator marking a test that requires an accelerator backend and PyTorch."""
-    return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")(
-        test_case
-    )
+    return pytest.mark.skipif(torch_device == "cpu", reason="test requires accelerator+PyTorch")(test_case)


 def require_torch_multi_gpu(test_case):
@@ -364,11 +362,11 @@ def require_torch_multi_gpu(test_case):
     -k "multi_gpu"
     """
     if not is_torch_available():
-        return unittest.skip("test requires PyTorch")(test_case)
+        return pytest.mark.skip(reason="test requires PyTorch")(test_case)

     import torch

-    return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case)
+    return pytest.mark.skipif(torch.cuda.device_count() <= 1, reason="test requires multiple GPUs")(test_case)


 def require_torch_multi_accelerator(test_case):
@@ -377,27 +375,28 @@ def require_torch_multi_accelerator(test_case):
     without multiple hardware accelerators.
     """
     if not is_torch_available():
-        return unittest.skip("test requires PyTorch")(test_case)
+        return pytest.mark.skip(reason="test requires PyTorch")(test_case)

     import torch

-    return unittest.skipUnless(
-        torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1, "test requires multiple hardware accelerators"
+    return pytest.mark.skipif(
+        not (torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1),
+        reason="test requires multiple hardware accelerators",
     )(test_case)


 def require_torch_accelerator_with_fp16(test_case):
     """Decorator marking a test that requires an accelerator with support for the FP16 data type."""
-    return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
-        test_case
-    )
+    return pytest.mark.skipif(
+        not _is_torch_fp16_available(torch_device), reason="test requires accelerator with fp16 support"
+    )(test_case)


 def require_torch_accelerator_with_fp64(test_case):
     """Decorator marking a test that requires an accelerator with support for the FP64 data type."""
-    return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")(
-        test_case
-    )
+    return pytest.mark.skipif(
+        not _is_torch_fp64_available(torch_device), reason="test requires accelerator with fp64 support"
+    )(test_case)


 def require_big_gpu_with_torch_cuda(test_case):
@@ -406,17 +405,17 @@ def require_big_gpu_with_torch_cuda(test_case):
     etc.
     """
     if not is_torch_available():
-        return unittest.skip("test requires PyTorch")(test_case)
+        return pytest.mark.skip(reason="test requires PyTorch")(test_case)

     import torch

     if not torch.cuda.is_available():
-        return unittest.skip("test requires PyTorch CUDA")(test_case)
+        return pytest.mark.skip(reason="test requires PyTorch CUDA")(test_case)

     device_properties = torch.cuda.get_device_properties(0)
     total_memory = device_properties.total_memory / (1024**3)
-    return unittest.skipUnless(
-        total_memory >= BIG_GPU_MEMORY, f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
+    return pytest.mark.skipif(
+        total_memory < BIG_GPU_MEMORY, reason=f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
     )(test_case)
@@ -430,12 +429,12 @@ def require_big_accelerator(test_case):
     test_case = pytest.mark.big_accelerator(test_case)

     if not is_torch_available():
-        return unittest.skip("test requires PyTorch")(test_case)
+        return pytest.mark.skip(reason="test requires PyTorch")(test_case)

     import torch

     if not (torch.cuda.is_available() or torch.xpu.is_available()):
-        return unittest.skip("test requires PyTorch CUDA")(test_case)
+        return pytest.mark.skip(reason="test requires PyTorch CUDA")(test_case)

     if torch.xpu.is_available():
         device_properties = torch.xpu.get_device_properties(0)
@@ -443,30 +442,30 @@ def require_big_accelerator(test_case):
         device_properties = torch.cuda.get_device_properties(0)

     total_memory = device_properties.total_memory / (1024**3)
-    return unittest.skipUnless(
-        total_memory >= BIG_GPU_MEMORY,
-        f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
+    return pytest.mark.skipif(
+        total_memory < BIG_GPU_MEMORY,
+        reason=f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
     )(test_case)


 def require_torch_accelerator_with_training(test_case):
     """Decorator marking a test that requires an accelerator with support for training."""
-    return unittest.skipUnless(
-        is_torch_available() and backend_supports_training(torch_device),
-        "test requires accelerator with training support",
+    return pytest.mark.skipif(
+        not (is_torch_available() and backend_supports_training(torch_device)),
+        reason="test requires accelerator with training support",
     )(test_case)


 def skip_mps(test_case):
     """Decorator marking a test to skip if torch_device is 'mps'"""
-    return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case)
+    return pytest.mark.skipif(torch_device == "mps", reason="test requires non 'mps' device")(test_case)


 def require_flax(test_case):
     """
     Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
     """
-    return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case)
+    return pytest.mark.skipif(not is_flax_available(), reason="test requires JAX & Flax")(test_case)


 def require_compel(test_case):
@@ -474,21 +473,21 @@ def require_compel(test_case):
     Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when
     the library is not installed.
     """
-    return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case)
+    return pytest.mark.skipif(not is_compel_available(), reason="test requires compel")(test_case)


 def require_onnxruntime(test_case):
     """
     Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.
     """
-    return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)
+    return pytest.mark.skipif(not is_onnx_available(), reason="test requires onnxruntime")(test_case)


 def require_note_seq(test_case):
     """
     Decorator marking a test that requires note_seq. These tests are skipped when note_seq isn't installed.
     """
-    return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case)
+    return pytest.mark.skipif(not is_note_seq_available(), reason="test requires note_seq")(test_case)


 def require_accelerator(test_case):
@@ -496,14 +495,14 @@ def require_accelerator(test_case):
     Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when there are no
     hardware accelerator available.
     """
-    return unittest.skipUnless(torch_device != "cpu", "test requires a hardware accelerator")(test_case)
+    return pytest.mark.skipif(torch_device == "cpu", reason="test requires a hardware accelerator")(test_case)


 def require_torchsde(test_case):
     """
     Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed.
     """
-    return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case)
+    return pytest.mark.skipif(not is_torchsde_available(), reason="test requires torchsde")(test_case)


 def require_peft_backend(test_case):
@@ -511,35 +510,35 @@ def require_peft_backend(test_case):
     Decorator marking a test that requires PEFT backend, this would require some specific versions of PEFT and
     transformers.
     """
-    return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
+    return pytest.mark.skipif(not USE_PEFT_BACKEND, reason="test requires PEFT backend")(test_case)


 def require_timm(test_case):
     """
     Decorator marking a test that requires timm. These tests are skipped when timm isn't installed.
     """
-    return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case)
+    return pytest.mark.skipif(not is_timm_available(), reason="test requires timm")(test_case)


 def require_bitsandbytes(test_case):
     """
     Decorator marking a test that requires bitsandbytes. These tests are skipped when bitsandbytes isn't installed.
     """
-    return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case)
+    return pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")(test_case)


 def require_quanto(test_case):
     """
     Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed.
     """
-    return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case)
+    return pytest.mark.skipif(not is_optimum_quanto_available(), reason="test requires quanto")(test_case)


 def require_accelerate(test_case):
     """
     Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
     """
-    return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case)
+    return pytest.mark.skipif(not is_accelerate_available(), reason="test requires accelerate")(test_case)


 def require_peft_version_greater(peft_version):
@@ -552,8 +551,8 @@ def require_peft_version_greater(peft_version):
         correct_peft_version = is_peft_available() and version.parse(
             version.parse(importlib.metadata.version("peft")).base_version
         ) > version.parse(peft_version)
-        return unittest.skipUnless(
-            correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}"
+        return pytest.mark.skipif(
+            not correct_peft_version, reason=f"test requires PEFT backend with the version greater than {peft_version}"
         )(test_case)

     return decorator
@@ -569,9 +568,9 @@ def require_transformers_version_greater(transformers_version):
         correct_transformers_version = is_transformers_available() and version.parse(
             version.parse(importlib.metadata.version("transformers")).base_version
         ) > version.parse(transformers_version)
-        return unittest.skipUnless(
-            correct_transformers_version,
-            f"test requires transformers with the version greater than {transformers_version}",
+        return pytest.mark.skipif(
+            not correct_transformers_version,
+            reason=f"test requires transformers with the version greater than {transformers_version}",
         )(test_case)

     return decorator
@@ -582,8 +581,9 @@ def require_accelerate_version_greater(accelerate_version):
         correct_accelerate_version = is_accelerate_available() and version.parse(
             version.parse(importlib.metadata.version("accelerate")).base_version
         ) > version.parse(accelerate_version)
-        return unittest.skipUnless(
-            correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}."
+        return pytest.mark.skipif(
+            not correct_accelerate_version,
+            reason=f"Test requires accelerate with the version greater than {accelerate_version}.",
         )(test_case)

     return decorator
@@ -594,8 +594,8 @@ def require_bitsandbytes_version_greater(bnb_version):
         correct_bnb_version = is_bitsandbytes_available() and version.parse(
             version.parse(importlib.metadata.version("bitsandbytes")).base_version
         ) > version.parse(bnb_version)
-        return unittest.skipUnless(
-            correct_bnb_version, f"Test requires bitsandbytes with the version greater than {bnb_version}."
+        return pytest.mark.skipif(
+            not correct_bnb_version, reason=f"Test requires bitsandbytes with the version greater than {bnb_version}."
        )(test_case)

     return decorator
@@ -606,8 +606,9 @@ def require_hf_hub_version_greater(hf_hub_version):
         correct_hf_hub_version = version.parse(
             version.parse(importlib.metadata.version("huggingface_hub")).base_version
         ) > version.parse(hf_hub_version)
-        return unittest.skipUnless(
-            correct_hf_hub_version, f"Test requires huggingface_hub with the version greater than {hf_hub_version}."
+        return pytest.mark.skipif(
+            not correct_hf_hub_version,
+            reason=f"Test requires huggingface_hub with the version greater than {hf_hub_version}.",
         )(test_case)

     return decorator
@@ -618,8 +619,8 @@ def require_gguf_version_greater_or_equal(gguf_version):
         correct_gguf_version = is_gguf_available() and version.parse(
             version.parse(importlib.metadata.version("gguf")).base_version
         ) >= version.parse(gguf_version)
-        return unittest.skipUnless(
-            correct_gguf_version, f"Test requires gguf with the version greater than {gguf_version}."
+        return pytest.mark.skipif(
+            not correct_gguf_version, reason=f"Test requires gguf with the version greater than {gguf_version}."
         )(test_case)

     return decorator
@@ -630,8 +631,8 @@ def require_torchao_version_greater_or_equal(torchao_version):
         correct_torchao_version = is_torchao_available() and version.parse(
             version.parse(importlib.metadata.version("torchao")).base_version
         ) >= version.parse(torchao_version)
-        return unittest.skipUnless(
-            correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}."
+        return pytest.mark.skipif(
+            not correct_torchao_version, reason=f"Test requires torchao with version greater than {torchao_version}."
         )(test_case)

     return decorator
@@ -642,8 +643,8 @@ def require_kernels_version_greater_or_equal(kernels_version):
         correct_kernels_version = is_kernels_available() and version.parse(
             version.parse(importlib.metadata.version("kernels")).base_version
         ) >= version.parse(kernels_version)
-        return unittest.skipUnless(
-            correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
+        return pytest.mark.skipif(
+            not correct_kernels_version, reason=f"Test requires kernels with version greater than {kernels_version}."
         )(test_case)

     return decorator
@@ -653,7 +654,7 @@ def deprecate_after_peft_backend(test_case):
     """
     Decorator marking a test that will be skipped after PEFT backend
     """
-    return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case)
+    return pytest.mark.skipif(USE_PEFT_BACKEND, reason="test skipped in favor of PEFT backend")(test_case)


 def get_python_version():
@@ -1064,8 +1065,8 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
     To run a test in a subprocess. In particular, this can avoid (GPU) memory issue.

     Args:
-        test_case (`unittest.TestCase`):
-            The test that will run `target_func`.
+        test_case:
+            The test case object that will run `target_func`.
         target_func (`Callable`):
             The function implementing the actual testing logic.
         inputs (`dict`, *optional*, defaults to `None`):
@@ -1083,7 +1084,7 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
     input_queue = ctx.Queue(1)
     output_queue = ctx.JoinableQueue(1)

-    # We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle.
+    # We can't send test case objects to the child, otherwise we get issues regarding pickle.
     input_queue.put(inputs, timeout=timeout)

     process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
......
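
All of these decorators keep their call signatures, so call sites elsewhere in the test suite need no changes. A hedged usage sketch (the test name and body are hypothetical; the import assumes the tests/testing_utils.py module touched by this commit):

    from tests.testing_utils import require_torch, slow

    @slow
    @require_torch
    def test_pipeline_numerics():  # hypothetical test
        ...  # collected by pytest; skipped unless RUN_SLOW is truthy and PyTorch is installed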