Unverified Commit 66e6a021 authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

[CI] Remove unittest dependency from `testing_utils.py` (#12621)



* update

* Update tests/testing_utils.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Update tests/testing_utils.py
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>

* Apply style fixes

---------
Co-authored-by: default avatarSayak Paul <spsayakpaul@gmail.com>
Co-authored-by: default avatargithub-actions[bot] <github-actions[bot]@users.noreply.github.com>
parent 5a47442f
...@@ -13,7 +13,6 @@ import struct ...@@ -13,7 +13,6 @@ import struct
import sys import sys
import tempfile import tempfile
import time import time
import unittest
import urllib.parse import urllib.parse
from collections import UserDict from collections import UserDict
from contextlib import contextmanager from contextlib import contextmanager
...@@ -24,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tupl ...@@ -24,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tupl
import numpy as np import numpy as np
import PIL.Image import PIL.Image
import PIL.ImageOps import PIL.ImageOps
import pytest
import requests import requests
from numpy.linalg import norm from numpy.linalg import norm
from packaging import version from packaging import version
...@@ -267,7 +267,7 @@ def slow(test_case): ...@@ -267,7 +267,7 @@ def slow(test_case):
Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them.
""" """
return unittest.skipUnless(_run_slow_tests, "test is slow")(test_case) return pytest.mark.skipif(not _run_slow_tests, reason="test is slow")(test_case)
def nightly(test_case): def nightly(test_case):
...@@ -277,7 +277,7 @@ def nightly(test_case): ...@@ -277,7 +277,7 @@ def nightly(test_case):
Slow tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them. Slow tests are skipped by default. Set the RUN_NIGHTLY environment variable to a truthy value to run them.
""" """
return unittest.skipUnless(_run_nightly_tests, "test is nightly")(test_case) return pytest.mark.skipif(not _run_nightly_tests, reason="test is nightly")(test_case)
def is_torch_compile(test_case): def is_torch_compile(test_case):
...@@ -287,23 +287,23 @@ def is_torch_compile(test_case): ...@@ -287,23 +287,23 @@ def is_torch_compile(test_case):
Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them. Compile tests are skipped by default. Set the RUN_COMPILE environment variable to a truthy value to run them.
""" """
return unittest.skipUnless(_run_compile_tests, "test is torch compile")(test_case) return pytest.mark.skipif(not _run_compile_tests, reason="test is torch compile")(test_case)
def require_torch(test_case): def require_torch(test_case):
""" """
Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed. Decorator marking a test that requires PyTorch. These tests are skipped when PyTorch isn't installed.
""" """
return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case) return pytest.mark.skipif(not is_torch_available(), reason="test requires PyTorch")(test_case)
def require_torch_2(test_case): def require_torch_2(test_case):
""" """
Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed. Decorator marking a test that requires PyTorch 2. These tests are skipped when it isn't installed.
""" """
return unittest.skipUnless(is_torch_available() and is_torch_version(">=", "2.0.0"), "test requires PyTorch 2")( return pytest.mark.skipif(
test_case not (is_torch_available() and is_torch_version(">=", "2.0.0")), reason="test requires PyTorch 2"
) )(test_case)
def require_torch_version_greater_equal(torch_version): def require_torch_version_greater_equal(torch_version):
...@@ -311,8 +311,9 @@ def require_torch_version_greater_equal(torch_version): ...@@ -311,8 +311,9 @@ def require_torch_version_greater_equal(torch_version):
def decorator(test_case): def decorator(test_case):
correct_torch_version = is_torch_available() and is_torch_version(">=", torch_version) correct_torch_version = is_torch_available() and is_torch_version(">=", torch_version)
return unittest.skipUnless( return pytest.mark.skipif(
correct_torch_version, f"test requires torch with the version greater than or equal to {torch_version}" not correct_torch_version,
reason=f"test requires torch with the version greater than or equal to {torch_version}",
)(test_case) )(test_case)
return decorator return decorator
...@@ -323,8 +324,8 @@ def require_torch_version_greater(torch_version): ...@@ -323,8 +324,8 @@ def require_torch_version_greater(torch_version):
def decorator(test_case): def decorator(test_case):
correct_torch_version = is_torch_available() and is_torch_version(">", torch_version) correct_torch_version = is_torch_available() and is_torch_version(">", torch_version)
return unittest.skipUnless( return pytest.mark.skipif(
correct_torch_version, f"test requires torch with the version greater than {torch_version}" not correct_torch_version, reason=f"test requires torch with the version greater than {torch_version}"
)(test_case) )(test_case)
return decorator return decorator
...@@ -332,19 +333,18 @@ def require_torch_version_greater(torch_version): ...@@ -332,19 +333,18 @@ def require_torch_version_greater(torch_version):
def require_torch_gpu(test_case): def require_torch_gpu(test_case):
"""Decorator marking a test that requires CUDA and PyTorch.""" """Decorator marking a test that requires CUDA and PyTorch."""
return unittest.skipUnless(is_torch_available() and torch_device == "cuda", "test requires PyTorch+CUDA")( return pytest.mark.skipif(torch_device != "cuda", reason="test requires PyTorch+CUDA")(test_case)
test_case
)
def require_torch_cuda_compatibility(expected_compute_capability): def require_torch_cuda_compatibility(expected_compute_capability):
def decorator(test_case): def decorator(test_case):
if torch.cuda.is_available(): if torch.cuda.is_available():
current_compute_capability = get_torch_cuda_device_capability() current_compute_capability = get_torch_cuda_device_capability()
return unittest.skipUnless( return pytest.mark.skipif(
float(current_compute_capability) == float(expected_compute_capability), float(current_compute_capability) != float(expected_compute_capability),
"Test not supported for this compute capability.", reason="Test not supported for this compute capability.",
) )(test_case)
return test_case
return decorator return decorator
...@@ -352,9 +352,7 @@ def require_torch_cuda_compatibility(expected_compute_capability): ...@@ -352,9 +352,7 @@ def require_torch_cuda_compatibility(expected_compute_capability):
# These decorators are for accelerator-specific behaviours that are not GPU-specific # These decorators are for accelerator-specific behaviours that are not GPU-specific
def require_torch_accelerator(test_case): def require_torch_accelerator(test_case):
"""Decorator marking a test that requires an accelerator backend and PyTorch.""" """Decorator marking a test that requires an accelerator backend and PyTorch."""
return unittest.skipUnless(is_torch_available() and torch_device != "cpu", "test requires accelerator+PyTorch")( return pytest.mark.skipif(torch_device == "cpu", reason="test requires accelerator+PyTorch")(test_case)
test_case
)
def require_torch_multi_gpu(test_case): def require_torch_multi_gpu(test_case):
...@@ -364,11 +362,11 @@ def require_torch_multi_gpu(test_case): ...@@ -364,11 +362,11 @@ def require_torch_multi_gpu(test_case):
-k "multi_gpu" -k "multi_gpu"
""" """
if not is_torch_available(): if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case) return pytest.mark.skip(reason="test requires PyTorch")(test_case)
import torch import torch
return unittest.skipUnless(torch.cuda.device_count() > 1, "test requires multiple GPUs")(test_case) return pytest.mark.skipif(torch.cuda.device_count() <= 1, reason="test requires multiple GPUs")(test_case)
def require_torch_multi_accelerator(test_case): def require_torch_multi_accelerator(test_case):
...@@ -377,27 +375,28 @@ def require_torch_multi_accelerator(test_case): ...@@ -377,27 +375,28 @@ def require_torch_multi_accelerator(test_case):
without multiple hardware accelerators. without multiple hardware accelerators.
""" """
if not is_torch_available(): if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case) return pytest.mark.skip(reason="test requires PyTorch")(test_case)
import torch import torch
return unittest.skipUnless( return pytest.mark.skipif(
torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1, "test requires multiple hardware accelerators" not (torch.cuda.device_count() > 1 or torch.xpu.device_count() > 1),
reason="test requires multiple hardware accelerators",
)(test_case) )(test_case)
def require_torch_accelerator_with_fp16(test_case): def require_torch_accelerator_with_fp16(test_case):
"""Decorator marking a test that requires an accelerator with support for the FP16 data type.""" """Decorator marking a test that requires an accelerator with support for the FP16 data type."""
return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")( return pytest.mark.skipif(
test_case not _is_torch_fp16_available(torch_device), reason="test requires accelerator with fp16 support"
) )(test_case)
def require_torch_accelerator_with_fp64(test_case): def require_torch_accelerator_with_fp64(test_case):
"""Decorator marking a test that requires an accelerator with support for the FP64 data type.""" """Decorator marking a test that requires an accelerator with support for the FP64 data type."""
return unittest.skipUnless(_is_torch_fp64_available(torch_device), "test requires accelerator with fp64 support")( return pytest.mark.skipif(
test_case not _is_torch_fp64_available(torch_device), reason="test requires accelerator with fp64 support"
) )(test_case)
def require_big_gpu_with_torch_cuda(test_case): def require_big_gpu_with_torch_cuda(test_case):
...@@ -406,17 +405,17 @@ def require_big_gpu_with_torch_cuda(test_case): ...@@ -406,17 +405,17 @@ def require_big_gpu_with_torch_cuda(test_case):
etc. etc.
""" """
if not is_torch_available(): if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case) return pytest.mark.skip(reason="test requires PyTorch")(test_case)
import torch import torch
if not torch.cuda.is_available(): if not torch.cuda.is_available():
return unittest.skip("test requires PyTorch CUDA")(test_case) return pytest.mark.skip(reason="test requires PyTorch CUDA")(test_case)
device_properties = torch.cuda.get_device_properties(0) device_properties = torch.cuda.get_device_properties(0)
total_memory = device_properties.total_memory / (1024**3) total_memory = device_properties.total_memory / (1024**3)
return unittest.skipUnless( return pytest.mark.skipif(
total_memory >= BIG_GPU_MEMORY, f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory" total_memory < BIG_GPU_MEMORY, reason=f"test requires a GPU with at least {BIG_GPU_MEMORY} GB memory"
)(test_case) )(test_case)
...@@ -430,12 +429,12 @@ def require_big_accelerator(test_case): ...@@ -430,12 +429,12 @@ def require_big_accelerator(test_case):
test_case = pytest.mark.big_accelerator(test_case) test_case = pytest.mark.big_accelerator(test_case)
if not is_torch_available(): if not is_torch_available():
return unittest.skip("test requires PyTorch")(test_case) return pytest.mark.skip(reason="test requires PyTorch")(test_case)
import torch import torch
if not (torch.cuda.is_available() or torch.xpu.is_available()): if not (torch.cuda.is_available() or torch.xpu.is_available()):
return unittest.skip("test requires PyTorch CUDA")(test_case) return pytest.mark.skip(reason="test requires PyTorch CUDA")(test_case)
if torch.xpu.is_available(): if torch.xpu.is_available():
device_properties = torch.xpu.get_device_properties(0) device_properties = torch.xpu.get_device_properties(0)
...@@ -443,30 +442,30 @@ def require_big_accelerator(test_case): ...@@ -443,30 +442,30 @@ def require_big_accelerator(test_case):
device_properties = torch.cuda.get_device_properties(0) device_properties = torch.cuda.get_device_properties(0)
total_memory = device_properties.total_memory / (1024**3) total_memory = device_properties.total_memory / (1024**3)
return unittest.skipUnless( return pytest.mark.skipif(
total_memory >= BIG_GPU_MEMORY, total_memory < BIG_GPU_MEMORY,
f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory", reason=f"test requires a hardware accelerator with at least {BIG_GPU_MEMORY} GB memory",
)(test_case) )(test_case)
def require_torch_accelerator_with_training(test_case): def require_torch_accelerator_with_training(test_case):
"""Decorator marking a test that requires an accelerator with support for training.""" """Decorator marking a test that requires an accelerator with support for training."""
return unittest.skipUnless( return pytest.mark.skipif(
is_torch_available() and backend_supports_training(torch_device), not (is_torch_available() and backend_supports_training(torch_device)),
"test requires accelerator with training support", reason="test requires accelerator with training support",
)(test_case) )(test_case)
def skip_mps(test_case): def skip_mps(test_case):
"""Decorator marking a test to skip if torch_device is 'mps'""" """Decorator marking a test to skip if torch_device is 'mps'"""
return unittest.skipUnless(torch_device != "mps", "test requires non 'mps' device")(test_case) return pytest.mark.skipif(torch_device == "mps", reason="test requires non 'mps' device")(test_case)
def require_flax(test_case): def require_flax(test_case):
""" """
Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed Decorator marking a test that requires JAX & Flax. These tests are skipped when one / both are not installed
""" """
return unittest.skipUnless(is_flax_available(), "test requires JAX & Flax")(test_case) return pytest.mark.skipif(not is_flax_available(), reason="test requires JAX & Flax")(test_case)
def require_compel(test_case): def require_compel(test_case):
...@@ -474,21 +473,21 @@ def require_compel(test_case): ...@@ -474,21 +473,21 @@ def require_compel(test_case):
Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when Decorator marking a test that requires compel: https://github.com/damian0815/compel. These tests are skipped when
the library is not installed. the library is not installed.
""" """
return unittest.skipUnless(is_compel_available(), "test requires compel")(test_case) return pytest.mark.skipif(not is_compel_available(), reason="test requires compel")(test_case)
def require_onnxruntime(test_case): def require_onnxruntime(test_case):
""" """
Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed. Decorator marking a test that requires onnxruntime. These tests are skipped when onnxruntime isn't installed.
""" """
return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case) return pytest.mark.skipif(not is_onnx_available(), reason="test requires onnxruntime")(test_case)
def require_note_seq(test_case): def require_note_seq(test_case):
""" """
Decorator marking a test that requires note_seq. These tests are skipped when note_seq isn't installed. Decorator marking a test that requires note_seq. These tests are skipped when note_seq isn't installed.
""" """
return unittest.skipUnless(is_note_seq_available(), "test requires note_seq")(test_case) return pytest.mark.skipif(not is_note_seq_available(), reason="test requires note_seq")(test_case)
def require_accelerator(test_case): def require_accelerator(test_case):
...@@ -496,14 +495,14 @@ def require_accelerator(test_case): ...@@ -496,14 +495,14 @@ def require_accelerator(test_case):
Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when there are no Decorator marking a test that requires a hardware accelerator backend. These tests are skipped when there are no
hardware accelerator available. hardware accelerator available.
""" """
return unittest.skipUnless(torch_device != "cpu", "test requires a hardware accelerator")(test_case) return pytest.mark.skipif(torch_device == "cpu", reason="test requires a hardware accelerator")(test_case)
def require_torchsde(test_case): def require_torchsde(test_case):
""" """
Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed. Decorator marking a test that requires torchsde. These tests are skipped when torchsde isn't installed.
""" """
return unittest.skipUnless(is_torchsde_available(), "test requires torchsde")(test_case) return pytest.mark.skipif(not is_torchsde_available(), reason="test requires torchsde")(test_case)
def require_peft_backend(test_case): def require_peft_backend(test_case):
...@@ -511,35 +510,35 @@ def require_peft_backend(test_case): ...@@ -511,35 +510,35 @@ def require_peft_backend(test_case):
Decorator marking a test that requires PEFT backend, this would require some specific versions of PEFT and Decorator marking a test that requires PEFT backend, this would require some specific versions of PEFT and
transformers. transformers.
""" """
return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case) return pytest.mark.skipif(not USE_PEFT_BACKEND, reason="test requires PEFT backend")(test_case)
def require_timm(test_case): def require_timm(test_case):
""" """
Decorator marking a test that requires timm. These tests are skipped when timm isn't installed. Decorator marking a test that requires timm. These tests are skipped when timm isn't installed.
""" """
return unittest.skipUnless(is_timm_available(), "test requires timm")(test_case) return pytest.mark.skipif(not is_timm_available(), reason="test requires timm")(test_case)
def require_bitsandbytes(test_case): def require_bitsandbytes(test_case):
""" """
Decorator marking a test that requires bitsandbytes. These tests are skipped when bitsandbytes isn't installed. Decorator marking a test that requires bitsandbytes. These tests are skipped when bitsandbytes isn't installed.
""" """
return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case) return pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes")(test_case)
def require_quanto(test_case): def require_quanto(test_case):
""" """
Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed. Decorator marking a test that requires quanto. These tests are skipped when quanto isn't installed.
""" """
return unittest.skipUnless(is_optimum_quanto_available(), "test requires quanto")(test_case) return pytest.mark.skipif(not is_optimum_quanto_available(), reason="test requires quanto")(test_case)
def require_accelerate(test_case): def require_accelerate(test_case):
""" """
Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed. Decorator marking a test that requires accelerate. These tests are skipped when accelerate isn't installed.
""" """
return unittest.skipUnless(is_accelerate_available(), "test requires accelerate")(test_case) return pytest.mark.skipif(not is_accelerate_available(), reason="test requires accelerate")(test_case)
def require_peft_version_greater(peft_version): def require_peft_version_greater(peft_version):
...@@ -552,8 +551,8 @@ def require_peft_version_greater(peft_version): ...@@ -552,8 +551,8 @@ def require_peft_version_greater(peft_version):
correct_peft_version = is_peft_available() and version.parse( correct_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version version.parse(importlib.metadata.version("peft")).base_version
) > version.parse(peft_version) ) > version.parse(peft_version)
return unittest.skipUnless( return pytest.mark.skipif(
correct_peft_version, f"test requires PEFT backend with the version greater than {peft_version}" not correct_peft_version, reason=f"test requires PEFT backend with the version greater than {peft_version}"
)(test_case) )(test_case)
return decorator return decorator
...@@ -569,9 +568,9 @@ def require_transformers_version_greater(transformers_version): ...@@ -569,9 +568,9 @@ def require_transformers_version_greater(transformers_version):
correct_transformers_version = is_transformers_available() and version.parse( correct_transformers_version = is_transformers_available() and version.parse(
version.parse(importlib.metadata.version("transformers")).base_version version.parse(importlib.metadata.version("transformers")).base_version
) > version.parse(transformers_version) ) > version.parse(transformers_version)
return unittest.skipUnless( return pytest.mark.skipif(
correct_transformers_version, not correct_transformers_version,
f"test requires transformers with the version greater than {transformers_version}", reason=f"test requires transformers with the version greater than {transformers_version}",
)(test_case) )(test_case)
return decorator return decorator
...@@ -582,8 +581,9 @@ def require_accelerate_version_greater(accelerate_version): ...@@ -582,8 +581,9 @@ def require_accelerate_version_greater(accelerate_version):
correct_accelerate_version = is_accelerate_available() and version.parse( correct_accelerate_version = is_accelerate_available() and version.parse(
version.parse(importlib.metadata.version("accelerate")).base_version version.parse(importlib.metadata.version("accelerate")).base_version
) > version.parse(accelerate_version) ) > version.parse(accelerate_version)
return unittest.skipUnless( return pytest.mark.skipif(
correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}." not correct_accelerate_version,
reason=f"Test requires accelerate with the version greater than {accelerate_version}.",
)(test_case) )(test_case)
return decorator return decorator
...@@ -594,8 +594,8 @@ def require_bitsandbytes_version_greater(bnb_version): ...@@ -594,8 +594,8 @@ def require_bitsandbytes_version_greater(bnb_version):
correct_bnb_version = is_bitsandbytes_available() and version.parse( correct_bnb_version = is_bitsandbytes_available() and version.parse(
version.parse(importlib.metadata.version("bitsandbytes")).base_version version.parse(importlib.metadata.version("bitsandbytes")).base_version
) > version.parse(bnb_version) ) > version.parse(bnb_version)
return unittest.skipUnless( return pytest.mark.skipif(
correct_bnb_version, f"Test requires bitsandbytes with the version greater than {bnb_version}." not correct_bnb_version, reason=f"Test requires bitsandbytes with the version greater than {bnb_version}."
)(test_case) )(test_case)
return decorator return decorator
...@@ -606,8 +606,9 @@ def require_hf_hub_version_greater(hf_hub_version): ...@@ -606,8 +606,9 @@ def require_hf_hub_version_greater(hf_hub_version):
correct_hf_hub_version = version.parse( correct_hf_hub_version = version.parse(
version.parse(importlib.metadata.version("huggingface_hub")).base_version version.parse(importlib.metadata.version("huggingface_hub")).base_version
) > version.parse(hf_hub_version) ) > version.parse(hf_hub_version)
return unittest.skipUnless( return pytest.mark.skipif(
correct_hf_hub_version, f"Test requires huggingface_hub with the version greater than {hf_hub_version}." not correct_hf_hub_version,
reason=f"Test requires huggingface_hub with the version greater than {hf_hub_version}.",
)(test_case) )(test_case)
return decorator return decorator
...@@ -618,8 +619,8 @@ def require_gguf_version_greater_or_equal(gguf_version): ...@@ -618,8 +619,8 @@ def require_gguf_version_greater_or_equal(gguf_version):
correct_gguf_version = is_gguf_available() and version.parse( correct_gguf_version = is_gguf_available() and version.parse(
version.parse(importlib.metadata.version("gguf")).base_version version.parse(importlib.metadata.version("gguf")).base_version
) >= version.parse(gguf_version) ) >= version.parse(gguf_version)
return unittest.skipUnless( return pytest.mark.skipif(
correct_gguf_version, f"Test requires gguf with the version greater than {gguf_version}." not correct_gguf_version, reason=f"Test requires gguf with the version greater than {gguf_version}."
)(test_case) )(test_case)
return decorator return decorator
...@@ -630,8 +631,8 @@ def require_torchao_version_greater_or_equal(torchao_version): ...@@ -630,8 +631,8 @@ def require_torchao_version_greater_or_equal(torchao_version):
correct_torchao_version = is_torchao_available() and version.parse( correct_torchao_version = is_torchao_available() and version.parse(
version.parse(importlib.metadata.version("torchao")).base_version version.parse(importlib.metadata.version("torchao")).base_version
) >= version.parse(torchao_version) ) >= version.parse(torchao_version)
return unittest.skipUnless( return pytest.mark.skipif(
correct_torchao_version, f"Test requires torchao with version greater than {torchao_version}." not correct_torchao_version, reason=f"Test requires torchao with version greater than {torchao_version}."
)(test_case) )(test_case)
return decorator return decorator
...@@ -642,8 +643,8 @@ def require_kernels_version_greater_or_equal(kernels_version): ...@@ -642,8 +643,8 @@ def require_kernels_version_greater_or_equal(kernels_version):
correct_kernels_version = is_kernels_available() and version.parse( correct_kernels_version = is_kernels_available() and version.parse(
version.parse(importlib.metadata.version("kernels")).base_version version.parse(importlib.metadata.version("kernels")).base_version
) >= version.parse(kernels_version) ) >= version.parse(kernels_version)
return unittest.skipUnless( return pytest.mark.skipif(
correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}." not correct_kernels_version, reason=f"Test requires kernels with version greater than {kernels_version}."
)(test_case) )(test_case)
return decorator return decorator
...@@ -653,7 +654,7 @@ def deprecate_after_peft_backend(test_case): ...@@ -653,7 +654,7 @@ def deprecate_after_peft_backend(test_case):
""" """
Decorator marking a test that will be skipped after PEFT backend Decorator marking a test that will be skipped after PEFT backend
""" """
return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case) return pytest.mark.skipif(USE_PEFT_BACKEND, reason="test skipped in favor of PEFT backend")(test_case)
def get_python_version(): def get_python_version():
...@@ -1064,8 +1065,8 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): ...@@ -1064,8 +1065,8 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
To run a test in a subprocess. In particular, this can avoid (GPU) memory issue. To run a test in a subprocess. In particular, this can avoid (GPU) memory issue.
Args: Args:
test_case (`unittest.TestCase`): test_case:
The test that will run `target_func`. The test case object that will run `target_func`.
target_func (`Callable`): target_func (`Callable`):
The function implementing the actual testing logic. The function implementing the actual testing logic.
inputs (`dict`, *optional*, defaults to `None`): inputs (`dict`, *optional*, defaults to `None`):
...@@ -1083,7 +1084,7 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None): ...@@ -1083,7 +1084,7 @@ def run_test_in_subprocess(test_case, target_func, inputs=None, timeout=None):
input_queue = ctx.Queue(1) input_queue = ctx.Queue(1)
output_queue = ctx.JoinableQueue(1) output_queue = ctx.JoinableQueue(1)
# We can't send `unittest.TestCase` to the child, otherwise we get issues regarding pickle. # We can't send test case objects to the child, otherwise we get issues regarding pickle.
input_queue.put(inputs, timeout=timeout) input_queue.put(inputs, timeout=timeout)
process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout)) process = ctx.Process(target=target_func, args=(input_queue, output_queue, timeout))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment