Unverified Commit 2f0f281b authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[Tests] restrict memory tests for quanto for certain schemes. (#11052)



* restrict memory tests for quanto for certain schemes.

* Apply suggestions from code review
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* fixes

* style

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent ccc83216
...@@ -101,6 +101,8 @@ if is_torch_available(): ...@@ -101,6 +101,8 @@ if is_torch_available():
mps_backend_registered = hasattr(torch.backends, "mps")
torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
from .torch_utils import get_torch_cuda_device_capability
def torch_all_close(a, b, *args, **kwargs):
    if not is_torch_available():
...@@ -282,6 +284,20 @@ def require_torch_gpu(test_case): ...@@ -282,6 +284,20 @@ def require_torch_gpu(test_case):
) )
def require_torch_cuda_compatibility(expected_compute_capability):
def decorator(test_case):
if not torch.cuda.is_available():
return unittest.skip(test_case)
else:
current_compute_capability = get_torch_cuda_device_capability()
return unittest.skipUnless(
float(current_compute_capability) == float(expected_compute_capability),
"Test not supported for this compute capability.",
)
return decorator
# These decorators are for accelerator-specific behaviours that are not GPU-specific # These decorators are for accelerator-specific behaviours that are not GPU-specific
def require_torch_accelerator(test_case): def require_torch_accelerator(test_case):
"""Decorator marking a test that requires an accelerator backend and PyTorch.""" """Decorator marking a test that requires an accelerator backend and PyTorch."""
......
...@@ -10,6 +10,7 @@ from diffusers.utils.testing_utils import ( ...@@ -10,6 +10,7 @@ from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
require_accelerate,
require_big_gpu_with_torch_cuda,
require_torch_cuda_compatibility,
torch_device,
)
...@@ -311,6 +312,7 @@ class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCa ...@@ -311,6 +312,7 @@ class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCa
return {"weights_dtype": "int8"}
@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.55
...@@ -318,6 +320,7 @@ class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCa ...@@ -318,6 +320,7 @@ class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCa
return {"weights_dtype": "int4"}
@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.65
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment