Unverified commit 2f0f281b, authored by Sayak Paul, committed by GitHub
Browse files

[Tests] restrict memory tests for quanto for certain schemes. (#11052)



* restrict memory tests for quanto for certain schemes.

* Apply suggestions from code review
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* fixes

* style

---------
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent ccc83216
......@@ -101,6 +101,8 @@ if is_torch_available():
mps_backend_registered = hasattr(torch.backends, "mps")
torch_device = "mps" if (mps_backend_registered and torch.backends.mps.is_available()) else torch_device
from .torch_utils import get_torch_cuda_device_capability
def torch_all_close(a, b, *args, **kwargs):
if not is_torch_available():
......@@ -282,6 +284,20 @@ def require_torch_gpu(test_case):
)
def require_torch_cuda_compatibility(expected_compute_capability):
def decorator(test_case):
if not torch.cuda.is_available():
return unittest.skip(test_case)
else:
current_compute_capability = get_torch_cuda_device_capability()
return unittest.skipUnless(
float(current_compute_capability) == float(expected_compute_capability),
"Test not supported for this compute capability.",
)
return decorator
# These decorators are for accelerator-specific behaviours that are not GPU-specific
def require_torch_accelerator(test_case):
"""Decorator marking a test that requires an accelerator backend and PyTorch."""
......
......@@ -10,6 +10,7 @@ from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance,
require_accelerate,
require_big_gpu_with_torch_cuda,
require_torch_cuda_compatibility,
torch_device,
)
......@@ -311,6 +312,7 @@ class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCa
return {"weights_dtype": "int8"}
@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
expected_memory_reduction = 0.55
......@@ -318,6 +320,7 @@ class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCa
return {"weights_dtype": "int4"}
@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
expected_memory_reduction = 0.65
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.