Unverified commit 33e636ce authored by Yao Matrix and committed by GitHub

enable torchao test cases on XPU and switch to device-agnostic APIs (#11654)



* enable torchao cases on XPU
Signed-off-by: Matrix YAO <matrix.yao@intel.com>

* device agnostic APIs
Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* more
Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* fix style
Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* enable test_torch_compile_recompilation_and_graph_break on XPU
Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* resolve comments
Signed-off-by: YAO Matrix <matrix.yao@intel.com>

---------

Signed-off-by: Matrix YAO <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
parent e27142ac
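The recurring pattern in this diff: direct `torch.cuda.*` calls are replaced by `backend_*` helpers from `diffusers.utils.testing_utils` that dispatch on the active `torch_device` string, so the same test bodies run on NVIDIA CUDA GPUs and Intel XPUs. As a rough sketch of the idea (the real `torch_device` selection in testing_utils may differ in detail):

```python
# Hedged sketch: how a test suite might pick the device string that the
# backend_* helpers dispatch on. diffusers' testing_utils exports a
# `torch_device` computed along these lines, though its actual logic differs.
import torch

if torch.cuda.is_available():
    torch_device = "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    torch_device = "xpu"
else:
    torch_device = "cpu"
```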
@@ -493,7 +493,7 @@ class TorchAoConfig(QuantizationConfigMixin):
         TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
         if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
             is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
-            if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9():
+            if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
                 raise ValueError(
                     f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
                     f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
@@ -645,7 +645,7 @@ class TorchAoConfig(QuantizationConfigMixin):
             QUANTIZATION_TYPES.update(INT8_QUANTIZATION_TYPES)
             QUANTIZATION_TYPES.update(UINTX_QUANTIZATION_DTYPES)
-            if cls._is_cuda_capability_atleast_8_9():
+            if cls._is_xpu_or_cuda_capability_atleast_8_9():
                 QUANTIZATION_TYPES.update(FLOATX_QUANTIZATION_TYPES)
             return QUANTIZATION_TYPES
@@ -655,14 +655,16 @@ class TorchAoConfig(QuantizationConfigMixin):
         )

     @staticmethod
-    def _is_cuda_capability_atleast_8_9() -> bool:
-        if not torch.cuda.is_available():
-            raise RuntimeError("TorchAO requires a CUDA compatible GPU and installation of PyTorch.")
-
-        major, minor = torch.cuda.get_device_capability()
-        if major == 8:
-            return minor >= 9
-        return major >= 9
+    def _is_xpu_or_cuda_capability_atleast_8_9() -> bool:
+        if torch.cuda.is_available():
+            major, minor = torch.cuda.get_device_capability()
+            if major == 8:
+                return minor >= 9
+            return major >= 9
+        elif torch.xpu.is_available():
+            return True
+        else:
+            raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU and installation of PyTorch.")

     def get_apply_tensor_subclass(self):
         TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
......
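Unrolled, the renamed gate above accepts two hardware classes: CUDA devices with compute capability 8.9 or newer (e.g. RTX 4090 at 8.9, H100 at 9.0) and any available Intel XPU. A compact, behaviorally equivalent restatement, assuming a PyTorch build that ships the `torch.xpu` module:

```python
# Standalone restatement of the new capability gate, for reference only.
import torch

def is_xpu_or_cuda_capability_atleast_8_9() -> bool:
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        # tuple comparison matches the if/return chain in the diff above
        return (major, minor) >= (8, 9)
    elif torch.xpu.is_available():
        return True  # all XPUs are treated as float8-capable
    raise RuntimeError("TorchAO requires a CUDA compatible GPU or Intel XPU.")
```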
@@ -300,9 +300,7 @@ def require_torch_gpu(test_case):
 def require_torch_cuda_compatibility(expected_compute_capability):
     def decorator(test_case):
-        if not torch.cuda.is_available():
-            return unittest.skip(test_case)
-        else:
+        if torch.cuda.is_available():
             current_compute_capability = get_torch_cuda_device_capability()
             return unittest.skipUnless(
                 float(current_compute_capability) == float(expected_compute_capability),
......
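This decorator gates a test on an exact CUDA compute capability; the fix stops unconditionally skipping on non-CUDA machines via `unittest.skip(test_case)`, which also passed the test object where a skip-reason string belongs. Hypothetical usage:

```python
# Hypothetical usage of the fixed decorator: the test runs only when the
# current CUDA device reports exactly the expected compute capability.
import unittest
from diffusers.utils.testing_utils import require_torch_cuda_compatibility

class NumericsTest(unittest.TestCase):          # hypothetical test class
    @require_torch_cuda_compatibility(8.0)      # e.g. pin to A100 (sm80) numerics
    def test_matches_reference_outputs(self):
        ...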
@@ -21,6 +21,7 @@ import torch
 from diffusers import ConsistencyDecoderVAE, StableDiffusionPipeline
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     load_image,
     slow,
@@ -162,13 +163,13 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
         # clean up the VRAM before each test
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     @torch.no_grad()
     def test_encode_decode(self):
......
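From here on, nearly every hunk makes the same two edits: import `backend_empty_cache` and call it with `torch_device` in place of `torch.cuda.empty_cache()`. A minimal sketch of what such a dispatcher does; the actual implementation in `diffusers.utils.testing_utils` may be organized differently (e.g. via a dispatch table) and cover more backends:

```python
# Minimal sketch of a backend_empty_cache-style helper, assuming a PyTorch
# build where torch.xpu / torch.mps are present for the respective devices.
import torch

def backend_empty_cache(device: str) -> None:
    """Release cached allocator memory on whichever backend `device` names."""
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "xpu":
        torch.xpu.empty_cache()
    elif device == "mps":
        torch.mps.empty_cache()
    # "cpu" needs no cache management, so it falls through as a no-op
```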
@@ -22,6 +22,7 @@ import torch
 from diffusers import UNet2DModel
 from diffusers.utils import logging
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
     require_torch_accelerator,
@@ -229,7 +230,7 @@ class UNetLDMModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         # two models don't need to stay in the device at the same time
         del model_accelerate
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)
         gc.collect()

         model_normal_load, _ = UNet2DModel.from_pretrained(
......
@@ -46,7 +46,6 @@ from diffusers.utils.testing_utils import (
     require_peft_backend,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
-    require_torch_gpu,
     skip_mps,
     slow,
     torch_all_close,
@@ -978,13 +977,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         assert sample2.allclose(sample5, atol=1e-4, rtol=1e-4)
         assert sample2.allclose(sample6, atol=1e-4, rtol=1e-4)

-    @require_torch_gpu
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy", None),
             ("hf-internal-testing/tiny-sd-unet-sharded-latest-format", "fp16"),
         ]
     )
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_from_hub(self, repo_id, variant):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         loaded_model = self.model_class.from_pretrained(repo_id, variant=variant)
@@ -994,13 +993,13 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

-    @require_torch_gpu
     @parameterized.expand(
         [
             ("hf-internal-testing/unet2d-sharded-dummy-subfolder", None),
             ("hf-internal-testing/tiny-sd-unet-sharded-latest-format-subfolder", "fp16"),
         ]
     )
+    @require_torch_accelerator
     def test_load_sharded_checkpoint_from_hub_subfolder(self, repo_id, variant):
         _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         loaded_model = self.model_class.from_pretrained(repo_id, subfolder="unet", variant=variant)
......
@@ -24,6 +24,7 @@ from transformers import AutoTokenizer, T5Config, T5EncoderModel
 from diffusers import AllegroPipeline, AllegroTransformer3DModel, AutoencoderKLAllegro, DDIMScheduler
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
     require_hf_hub_version_greater,
@@ -341,12 +342,12 @@ class AllegroPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_allegro(self):
         generator = torch.Generator("cpu").manual_seed(0)
......
@@ -37,7 +37,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, nightly, torch_device
+from diffusers.utils.testing_utils import backend_empty_cache, enable_full_determinism, nightly, torch_device

 from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin
@@ -378,12 +378,12 @@ class AudioLDMPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
         generator = torch.Generator(device=generator_device).manual_seed(seed)
@@ -423,12 +423,12 @@ class AudioLDMPipelineNightlyTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
         generator = torch.Generator(device=generator_device).manual_seed(seed)
......
@@ -45,7 +45,13 @@ from diffusers import (
     LMSDiscreteScheduler,
     PNDMScheduler,
 )
-from diffusers.utils.testing_utils import enable_full_determinism, is_torch_version, nightly, torch_device
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    enable_full_determinism,
+    is_torch_version,
+    nightly,
+    torch_device,
+)

 from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS, TEXT_TO_AUDIO_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin
@@ -540,12 +546,12 @@ class AudioLDM2PipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
         generator = torch.Generator(device=generator_device).manual_seed(seed)
......
@@ -22,6 +22,7 @@ from transformers import AutoTokenizer, T5EncoderModel
 from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel, DDIMScheduler
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
     require_torch_accelerator,
@@ -334,12 +335,12 @@ class CogVideoXPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_cogvideox(self):
         generator = torch.Generator("cpu").manual_seed(0)
......
@@ -22,6 +22,7 @@ from transformers import AutoTokenizer, T5EncoderModel
 from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipeline, CogView3PlusTransformer2DModel
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
     require_torch_accelerator,
@@ -244,12 +245,12 @@ class CogView3PlusPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_cogview3plus(self):
         generator = torch.Generator("cpu").manual_seed(0)
......
@@ -36,6 +36,7 @@ from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
 from diffusers.utils import load_image
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
     load_numpy,
@@ -412,12 +413,12 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_canny(self):
         controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
......
@@ -36,6 +36,7 @@ from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
 from diffusers.utils import load_image
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     floats_tensor,
     load_numpy,
@@ -464,12 +465,12 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_canny(self):
         controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
......
@@ -221,7 +221,7 @@ class StableDiffusion3ControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
 @slow
 @require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
+@pytest.mark.big_accelerator
 class StableDiffusion3ControlNetPipelineSlowTests(unittest.TestCase):
     pipeline_class = StableDiffusion3ControlNetPipeline
......
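Here and in the Flux tests below, the pytest marker `big_gpu_with_torch_cuda` is renamed to the device-neutral `big_accelerator`. For `pytest -m big_accelerator` selection to stay clean under `--strict-markers`, the marker has to be registered somewhere; a hypothetical `conftest.py` registration (the repository may declare it in `pyproject.toml` or `pytest.ini` instead):

```python
# Hypothetical conftest.py hook registering the renamed marker so pytest
# recognizes `-m big_accelerator` without strict-marker warnings.
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "big_accelerator: tests that need a large-memory accelerator"
    )
```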
@@ -25,6 +25,7 @@ from diffusers.models.attention_processor import AttnAddedKVProcessor
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
+    backend_max_memory_allocated,
     backend_reset_max_memory_allocated,
     backend_reset_peak_memory_stats,
     load_numpy,
@@ -135,7 +136,7 @@ class IFPipelineSlowTests(unittest.TestCase):
         image = output.images[0]

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
         assert mem_bytes < 12 * 10**9

         expected_image = load_numpy(
......
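These memory tests now read peak allocation through `backend_max_memory_allocated` instead of `torch.cuda.max_memory_allocated()`. A hedged sketch of such a reader, assuming a PyTorch build whose `torch.xpu` module mirrors the CUDA memory counters (recent releases do); the real helper in `diffusers.utils.testing_utils` may be structured differently:

```python
# Sketch of a device-agnostic peak-memory reader used by assertions like
# `assert mem_bytes < 12 * 10**9` in the hunks around this point.
import torch

def backend_max_memory_allocated(device: str) -> int:
    if device == "cuda":
        return torch.cuda.max_memory_allocated()
    elif device == "xpu":
        return torch.xpu.max_memory_allocated()
    raise ValueError(f"no peak-memory counter wired up for device {device!r}")
```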
@@ -24,6 +24,7 @@ from diffusers.models.attention_processor import AttnAddedKVProcessor
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
+    backend_max_memory_allocated,
     backend_reset_max_memory_allocated,
     backend_reset_peak_memory_stats,
     floats_tensor,
@@ -151,7 +152,7 @@ class IFImg2ImgPipelineSlowTests(unittest.TestCase):
         )
         image = output.images[0]

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
         assert mem_bytes < 12 * 10**9

         expected_image = load_numpy(
......
@@ -224,7 +224,7 @@ class FluxPipelineFastTests(
 @nightly
 @require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
+@pytest.mark.big_accelerator
 class FluxPipelineSlowTests(unittest.TestCase):
     pipeline_class = FluxPipeline
     repo_id = "black-forest-labs/FLUX.1-schnell"
@@ -312,7 +312,7 @@ class FluxPipelineSlowTests(unittest.TestCase):
 @slow
 @require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
+@pytest.mark.big_accelerator
 class FluxIPAdapterPipelineSlowTests(unittest.TestCase):
     pipeline_class = FluxPipeline
     repo_id = "black-forest-labs/FLUX.1-dev"
......
@@ -19,7 +19,7 @@ from diffusers.utils.testing_utils import (
 @slow
 @require_big_accelerator
-@pytest.mark.big_gpu_with_torch_cuda
+@pytest.mark.big_accelerator
 class FluxReduxSlowTests(unittest.TestCase):
     pipeline_class = FluxPriorReduxPipeline
     repo_id = "black-forest-labs/FLUX.1-Redux-dev"
......
@@ -23,6 +23,7 @@ from transformers import AutoTokenizer, BertModel, T5EncoderModel
 from diffusers import AutoencoderKL, DDPMScheduler, HunyuanDiT2DModel, HunyuanDiTPipeline
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
     require_torch_accelerator,
@@ -310,12 +311,12 @@ class HunyuanDiTPipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_hunyuan_dit_1024(self):
         generator = torch.Generator("cpu").manual_seed(0)
......
@@ -27,6 +27,7 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     load_numpy,
     nightly,
     numpy_cosine_similarity_distance,
@@ -231,12 +232,12 @@ class StableDiffusionAttendAndExcitePipelineIntegrationTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_attend_and_excite_fp16(self):
         generator = torch.manual_seed(51)
......
@@ -25,6 +25,7 @@ from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 from diffusers import AutoencoderKL, PNDMScheduler, StableDiffusionInpaintPipeline, UNet2DConditionModel
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
+    backend_max_memory_allocated,
     backend_reset_max_memory_allocated,
     backend_reset_peak_memory_stats,
     enable_full_determinism,
@@ -287,6 +288,6 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase):
             output_type="np",
         )

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
         # make sure that less than 2.65 GB is allocated
         assert mem_bytes < 2.65 * 10**9