Unverified Commit ec37e209 authored by Fanli Lin, committed by GitHub

[tests] make tests device-agnostic (part 3) (#10437)



* initial commit

* fix empty cache

* fix one more

* fix style

* update device functions

* update

* update

* Update src/diffusers/utils/testing_utils.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/utils/testing_utils.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/utils/testing_utils.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update tests/pipelines/controlnet/test_controlnet.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/utils/testing_utils.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update src/diffusers/utils/testing_utils.py
Co-authored-by: hlky <hlky@hlky.ac>

* Update tests/pipelines/controlnet/test_controlnet.py
Co-authored-by: hlky <hlky@hlky.ac>

* with gc.collect

* update

* make style

* check_torch_dependencies

* add mps empty cache

* bug fix

* Apply suggestions from code review

---------
Co-authored-by: hlky <hlky@hlky.ac>
parent 158a5a87
@@ -86,7 +86,12 @@ if is_torch_available():
             ) from e
         logger.info(f"torch_device overrode to {torch_device}")
     else:
-        torch_device = "cuda" if torch.cuda.is_available() else "cpu"
+        if torch.cuda.is_available():
+            torch_device = "cuda"
+        elif torch.xpu.is_available():
+            torch_device = "xpu"
+        else:
+            torch_device = "cpu"
     is_torch_higher_equal_than_1_12 = version.parse(
         version.parse(torch.__version__).base_version
     ) >= version.parse("1.12")
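The hunk above changes the default device resolution from a binary cuda/cpu choice to a cuda, then xpu, then cpu fallback. A minimal standalone sketch of the same precedence (`resolve_test_device` is a hypothetical helper name; the `hasattr` guard is an extra safety net for PyTorch builds that ship no `torch.xpu` module, which the patched file can omit):

    import torch

    def resolve_test_device() -> str:
        # Same precedence as the patched testing_utils: CUDA, then XPU, then CPU.
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return "xpu"
        return "cpu"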
@@ -1067,12 +1072,51 @@ def _is_torch_fp64_available(device):
 # Guard these lookups for when Torch is not used - alternative accelerator support is for PyTorch
 if is_torch_available():
     # Behaviour flags
-    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "cpu": True, "mps": False, "default": True}
+    BACKEND_SUPPORTS_TRAINING = {"cuda": True, "xpu": True, "cpu": True, "mps": False, "default": True}

     # Function definitions
-    BACKEND_EMPTY_CACHE = {"cuda": torch.cuda.empty_cache, "cpu": None, "mps": None, "default": None}
-    BACKEND_DEVICE_COUNT = {"cuda": torch.cuda.device_count, "cpu": lambda: 0, "mps": lambda: 0, "default": 0}
-    BACKEND_MANUAL_SEED = {"cuda": torch.cuda.manual_seed, "cpu": torch.manual_seed, "default": torch.manual_seed}
+    BACKEND_EMPTY_CACHE = {
+        "cuda": torch.cuda.empty_cache,
+        "xpu": torch.xpu.empty_cache,
+        "cpu": None,
+        "mps": torch.mps.empty_cache,
+        "default": None,
+    }
+    BACKEND_DEVICE_COUNT = {
+        "cuda": torch.cuda.device_count,
+        "xpu": torch.xpu.device_count,
+        "cpu": lambda: 0,
+        "mps": lambda: 0,
+        "default": 0,
+    }
+    BACKEND_MANUAL_SEED = {
+        "cuda": torch.cuda.manual_seed,
+        "xpu": torch.xpu.manual_seed,
+        "cpu": torch.manual_seed,
+        "mps": torch.mps.manual_seed,
+        "default": torch.manual_seed,
+    }
+    BACKEND_RESET_PEAK_MEMORY_STATS = {
+        "cuda": torch.cuda.reset_peak_memory_stats,
+        "xpu": getattr(torch.xpu, "reset_peak_memory_stats", None),
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
+    BACKEND_RESET_MAX_MEMORY_ALLOCATED = {
+        "cuda": torch.cuda.reset_max_memory_allocated,
+        "xpu": None,
+        "cpu": None,
+        "mps": None,
+        "default": None,
+    }
+    BACKEND_MAX_MEMORY_ALLOCATED = {
+        "cuda": torch.cuda.max_memory_allocated,
+        "xpu": getattr(torch.xpu, "max_memory_allocated", None),
+        "cpu": 0,
+        "mps": 0,
+        "default": 0,
+    }

 # This dispatches a defined function according to the accelerator from the function definitions.
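The comment above refers to `_device_agnostic_dispatch`, which the new `backend_*` helpers below call into. A sketch of how such a dispatch table is consumed, assuming entries are callables, plain values, or None as in the mappings above (an illustrative re-implementation, not the actual diffusers helper):

    def dispatch(device: str, table: dict, *args, **kwargs):
        fn = table.get(device, table["default"])
        if fn is None:
            return None  # feature unsupported on this backend, e.g. empty_cache on cpu
        if not callable(fn):
            return fn  # plain-value entries such as "cpu": 0 in BACKEND_MAX_MEMORY_ALLOCATED
        return fn(*args, **kwargs)

    # dispatch("xpu", BACKEND_EMPTY_CACHE) would call torch.xpu.empty_cache()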
@@ -1103,6 +1147,18 @@ def backend_device_count(device: str):
     return _device_agnostic_dispatch(device, BACKEND_DEVICE_COUNT)


+def backend_reset_peak_memory_stats(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_PEAK_MEMORY_STATS)
+
+
+def backend_reset_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_RESET_MAX_MEMORY_ALLOCATED)
+
+
+def backend_max_memory_allocated(device: str):
+    return _device_agnostic_dispatch(device, BACKEND_MAX_MEMORY_ALLOCATED)
+
+
 # These are callables which return boolean behaviour flags and can be used to specify some
 # device agnostic alternative where the feature is unsupported.
 def backend_supports_training(device: str):
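With these three helpers exposed, a slow test can do memory bookkeeping without touching `torch.cuda` directly. A sketch of the pattern the test files below adopt (the class name and the 4 GB budget are placeholders):

    import gc
    import unittest

    from diffusers.utils.testing_utils import (
        backend_empty_cache,
        backend_max_memory_allocated,
        backend_reset_peak_memory_stats,
        torch_device,
    )

    class ExampleSlowTests(unittest.TestCase):
        def setUp(self):
            super().setUp()
            gc.collect()
            backend_empty_cache(torch_device)  # no-op on backends without a cache

        def test_memory_budget(self):
            backend_reset_peak_memory_stats(torch_device)
            ...  # run the pipeline under test here
            assert backend_max_memory_allocated(torch_device) < 4 * 10**9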
@@ -1159,3 +1215,6 @@ if is_torch_available():
     update_mapping_from_spec(BACKEND_EMPTY_CACHE, "EMPTY_CACHE_FN")
     update_mapping_from_spec(BACKEND_DEVICE_COUNT, "DEVICE_COUNT_FN")
     update_mapping_from_spec(BACKEND_SUPPORTS_TRAINING, "SUPPORTS_TRAINING")
+    update_mapping_from_spec(BACKEND_RESET_PEAK_MEMORY_STATS, "RESET_PEAK_MEMORY_STATS_FN")
+    update_mapping_from_spec(BACKEND_RESET_MAX_MEMORY_ALLOCATED, "RESET_MAX_MEMORY_ALLOCATED_FN")
+    update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")
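The final hunk wires the new mappings into the custom device-spec mechanism: a user-supplied spec module can override each entry via the named attribute. A sketch of what such a spec file might contain (only the attribute names are taken from the `update_mapping_from_spec` calls above; the file name, device name, and values are assumptions for illustration):

    # my_device_spec.py -- hypothetical spec for an out-of-tree accelerator
    DEVICE_NAME = "my_accel"

    EMPTY_CACHE_FN = lambda: None       # feeds BACKEND_EMPTY_CACHE
    DEVICE_COUNT_FN = lambda: 1         # feeds BACKEND_DEVICE_COUNT
    SUPPORTS_TRAINING = True            # feeds BACKEND_SUPPORTS_TRAINING
    RESET_PEAK_MEMORY_STATS_FN = None   # new hooks added by this PR
    RESET_MAX_MEMORY_ALLOCATED_FN = None
    MAX_MEMORY_ALLOCATED_FN = lambda: 0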
@@ -57,8 +57,8 @@ from diffusers.utils.testing_utils import (
     get_python_version,
     is_torch_compile,
     require_torch_2,
+    require_torch_accelerator,
     require_torch_accelerator_with_training,
-    require_torch_gpu,
     require_torch_multi_gpu,
     run_test_in_subprocess,
     torch_all_close,
@@ -543,7 +543,7 @@ class ModelTesterMixin:
         assert torch.allclose(output, output_3, atol=self.base_precision)
         assert torch.allclose(output_2, output_3, atol=self.base_precision)

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_set_attn_processor_for_determinism(self):
         if self.uses_custom_attn_processor:
             return
@@ -1068,7 +1068,7 @@ class ModelTesterMixin:
         self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception))

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_cpu_offload(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
@@ -1098,7 +1098,7 @@ class ModelTesterMixin:
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_disk_offload_without_safetensors(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
@@ -1132,7 +1132,7 @@ class ModelTesterMixin:
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_disk_offload_with_safetensors(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
@@ -1191,7 +1191,7 @@ class ModelTesterMixin:
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_sharded_checkpoints(self):
         torch.manual_seed(0)
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1223,7 +1223,7 @@ class ModelTesterMixin:
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_sharded_checkpoints_with_variant(self):
         torch.manual_seed(0)
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
@@ -1261,7 +1261,7 @@ class ModelTesterMixin:
         self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_sharded_checkpoints_device_map(self):
         config, inputs_dict = self.prepare_init_args_and_inputs_for_common()
         model = self.model_class(**config).eval()
...
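The recurring change in this file swaps `@require_torch_gpu` for `@require_torch_accelerator`, so the tests run on any supported accelerator rather than CUDA only. A plausible sketch of what such a decorator does (an illustrative re-implementation; the real one lives in diffusers.utils.testing_utils):

    import unittest

    from diffusers.utils.testing_utils import torch_device

    def require_torch_accelerator(test_case):
        # Skip unless the resolved test device is a hardware accelerator.
        return unittest.skipUnless(
            torch_device is not None and torch_device != "cpu",
            "test requires a hardware accelerator",
        )(test_case)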
@@ -27,7 +27,7 @@ from diffusers.utils.testing_utils import (
     enable_full_determinism,
     numpy_cosine_similarity_distance,
     require_hf_hub_version_greater,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_transformers_version_greater,
     slow,
     torch_device,
@@ -332,7 +332,7 @@ class AllegroPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class AllegroPipelineIntegrationTests(unittest.TestCase):
     prompt = "A painting of a squirrel eating a burger."
@@ -350,7 +350,7 @@ class AllegroPipelineIntegrationTests(unittest.TestCase):
         generator = torch.Generator("cpu").manual_seed(0)

         pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.float16)
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)

         prompt = self.prompt
         videos = pipe(
...
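`enable_model_cpu_offload()` defaults to CUDA, which is why every call site in this PR now passes the resolved device explicitly. A usage sketch (the checkpoint name is just an example):

    import torch

    from diffusers import DiffusionPipeline
    from diffusers.utils.testing_utils import torch_device

    pipe = DiffusionPipeline.from_pretrained(
        "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
    )
    # Sub-models move to the accelerator one at a time and back to CPU after use.
    pipe.enable_model_cpu_offload(device=torch_device)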
@@ -20,9 +20,10 @@ from diffusers import (
 from diffusers.models.attention import FreeNoiseTransformerBlock
 from diffusers.utils import is_xformers_available, logging
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     numpy_cosine_similarity_distance,
     require_accelerator,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -547,19 +548,19 @@ class AnimateDiffPipelineFastTests(

 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class AnimateDiffPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         # clean up the VRAM before each test
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_animatediff(self):
         adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
@@ -573,7 +574,7 @@ class AnimateDiffPipelineSlowTests(unittest.TestCase):
             clip_sample=False,
         )
         pipe.enable_vae_slicing()
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         prompt = "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
...
@@ -24,7 +24,7 @@ from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransf
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     numpy_cosine_similarity_distance,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -321,7 +321,7 @@ class CogVideoXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class CogVideoXPipelineIntegrationTests(unittest.TestCase):
     prompt = "A painting of a squirrel eating a burger."
@@ -339,7 +339,7 @@ class CogVideoXPipelineIntegrationTests(unittest.TestCase):
         generator = torch.Generator("cpu").manual_seed(0)

         pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)

         prompt = self.prompt
         videos = pipe(
...
@@ -24,9 +24,10 @@ from transformers import AutoTokenizer, T5EncoderModel
 from diffusers import AutoencoderKLCogVideoX, CogVideoXImageToVideoPipeline, CogVideoXTransformer3DModel, DDIMScheduler
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     numpy_cosine_similarity_distance,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -344,25 +345,25 @@ class CogVideoXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestC

 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class CogVideoXImageToVideoPipelineIntegrationTests(unittest.TestCase):
     prompt = "A painting of a squirrel eating a burger."

     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_cogvideox(self):
         generator = torch.Generator("cpu").manual_seed(0)

         pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16)
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)

         prompt = self.prompt
         image = load_image(
...
@@ -24,7 +24,7 @@ from diffusers import AutoencoderKL, CogVideoXDDIMScheduler, CogView3PlusPipelin
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     numpy_cosine_similarity_distance,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -232,7 +232,7 @@ class CogView3PlusPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class CogView3PlusPipelineIntegrationTests(unittest.TestCase):
     prompt = "A painting of a squirrel eating a burger."
@@ -250,7 +250,7 @@ class CogView3PlusPipelineIntegrationTests(unittest.TestCase):
         generator = torch.Generator("cpu").manual_seed(0)

         pipe = CogView3PlusPipeline.from_pretrained("THUDM/CogView3Plus-3b", torch_dtype=torch.float16)
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)

         prompt = self.prompt
         images = pipe(
...
@@ -34,13 +34,17 @@ from diffusers import (
 from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_max_memory_allocated,
+    backend_reset_peak_memory_stats,
     enable_full_determinism,
     get_python_version,
     is_torch_compile,
     load_image,
     load_numpy,
     require_torch_2,
-    require_torch_gpu,
+    require_torch_accelerator,
     run_test_in_subprocess,
     slow,
     torch_device,
@@ -703,17 +707,17 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(

 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class ControlNetPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_canny(self):
         controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
@@ -721,7 +725,7 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionControlNetPipeline.from_pretrained(
             "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -748,7 +752,7 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionControlNetPipeline.from_pretrained(
             "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -775,7 +779,7 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionControlNetPipeline.from_pretrained(
             "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -802,7 +806,7 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionControlNetPipeline.from_pretrained(
             "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -829,7 +833,7 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionControlNetPipeline.from_pretrained(
             "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -856,7 +860,7 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionControlNetPipeline.from_pretrained(
             "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -883,7 +887,7 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionControlNetPipeline.from_pretrained(
             "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(5)
@@ -910,7 +914,7 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionControlNetPipeline.from_pretrained(
             "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(5)
@@ -932,9 +936,9 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
         assert np.abs(expected_image - image).max() < 8e-2

     def test_sequential_cpu_offloading(self):
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.reset_peak_memory_stats()
+        backend_empty_cache(torch_device)
+        backend_reset_max_memory_allocated(torch_device)
+        backend_reset_peak_memory_stats(torch_device)

         controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-seg")
@@ -943,7 +947,7 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
         )
         pipe.set_progress_bar_config(disable=None)
         pipe.enable_attention_slicing()
-        pipe.enable_sequential_cpu_offload()
+        pipe.enable_sequential_cpu_offload(device=torch_device)

         prompt = "house"
         image = load_image(
@@ -957,7 +961,7 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
             output_type="np",
         )

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
         # make sure that less than 7 GB is allocated
         assert mem_bytes < 4 * 10**9
@@ -967,7 +971,7 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionControlNetPipeline.from_pretrained(
             "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -1000,7 +1004,7 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
             "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
         )
         pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -1041,7 +1045,7 @@ class ControlNetPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionControlNetPipeline.from_pretrained(
             "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -1068,17 +1072,17 @@ class ControlNetPipelineSlowTests(unittest.TestCase):

 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class StableDiffusionMultiControlNetPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_pose_and_canny(self):
         controlnet_canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny")
@@ -1089,7 +1093,7 @@ class StableDiffusionMultiControlNetPipelineSlowTests(unittest.TestCase):
             safety_checker=None,
             controlnet=[controlnet_pose, controlnet_canny],
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
...
@@ -39,7 +39,7 @@ from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
     load_numpy,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -393,7 +393,7 @@ class StableDiffusionMultiControlNetPipelineFastTests(

 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
@@ -411,7 +411,7 @@ class ControlNetImg2ImgPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
             "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
...
@@ -40,7 +40,7 @@ from diffusers.utils.testing_utils import (
     floats_tensor,
     load_numpy,
     numpy_cosine_similarity_distance,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -445,7 +445,7 @@ class MultiControlNetInpaintPipelineFastTests(

 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
@@ -463,7 +463,7 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
             "botp/stable-diffusion-v1-5-inpainting", safety_checker=None, controlnet=controlnet
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -509,7 +509,7 @@ class ControlNetInpaintPipelineSlowTests(unittest.TestCase):
             "stable-diffusion-v1-5/stable-diffusion-v1-5", safety_checker=None, controlnet=controlnet
         )
         pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(33)
...
@@ -35,9 +35,10 @@ from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2D
 from diffusers.pipelines.controlnet.pipeline_controlnet import MultiControlNetModel
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     load_image,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -212,7 +213,7 @@ class StableDiffusionXLControlNetPipelineFastTests(
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_stable_diffusion_xl_offloads(self):
         pipes = []
         components = self.get_dummy_components()
@@ -893,17 +894,17 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(

 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class ControlNetSDXLPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_canny(self):
         controlnet = ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
@@ -911,7 +912,7 @@ class ControlNetSDXLPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
             "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
         )
-        pipe.enable_sequential_cpu_offload()
+        pipe.enable_sequential_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -934,7 +935,7 @@ class ControlNetSDXLPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
             "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet
         )
-        pipe.enable_sequential_cpu_offload()
+        pipe.enable_sequential_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
...
@@ -28,7 +28,12 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, require_torch_gpu, torch_device
+from diffusers.utils.testing_utils import (
+    enable_full_determinism,
+    floats_tensor,
+    require_torch_accelerator,
+    torch_device,
+)

 from ..pipeline_params import (
     IMAGE_TO_IMAGE_IMAGE_PARAMS,
@@ -241,7 +246,7 @@ class ControlNetPipelineSDXLImg2ImgFastTests(
     def test_save_load_optional_components(self):
         pass

-    @require_torch_gpu
+    @require_torch_accelerator
     def test_stable_diffusion_xl_offloads(self):
         pipes = []
         components = self.get_dummy_components()
@@ -250,12 +255,12 @@ class ControlNetPipelineSDXLImg2ImgFastTests(

         components = self.get_dummy_components()
         sd_pipe = self.pipeline_class(**components)
-        sd_pipe.enable_model_cpu_offload()
+        sd_pipe.enable_model_cpu_offload(device=torch_device)
         pipes.append(sd_pipe)

         components = self.get_dummy_components()
         sd_pipe = self.pipeline_class(**components)
-        sd_pipe.enable_sequential_cpu_offload()
+        sd_pipe.enable_sequential_cpu_offload(device=torch_device)
         pipes.append(sd_pipe)

         image_slices = []
...
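The offload test above builds the same pipeline twice, once with model-level offload and once with sequential (sub-module level) offload, then compares output slices. A condensed sketch of that comparison, with `pipeline_class`, `components`, and `inputs` standing in for the test fixtures:

    import numpy as np

    from diffusers.utils.testing_utils import torch_device

    def compare_offload_modes(pipeline_class, components, inputs):
        slices = []
        for method in ("enable_model_cpu_offload", "enable_sequential_cpu_offload"):
            pipe = pipeline_class(**components)
            getattr(pipe, method)(device=torch_device)
            image = pipe(**inputs).images[0]
            slices.append(image[-3:, -3:, -1].flatten())
        # Both offload strategies must produce (near-)identical results.
        assert np.abs(slices[0] - slices[1]).max() < 1e-3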
@@ -29,8 +29,9 @@ from diffusers import (
 from diffusers.models import HunyuanDiT2DControlNetModel, HunyuanDiT2DMultiControlNetModel
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
-    require_torch_gpu,
+    require_torch_accelerator,
     slow,
     torch_device,
 )
@@ -178,19 +179,19 @@ class HunyuanDiTControlNetPipelineFastTests(unittest.TestCase, PipelineTesterMix

 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class HunyuanDiTControlNetPipelineSlowTests(unittest.TestCase):
     pipeline_class = HunyuanDiTControlNetPipeline

     def setUp(self):
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_canny(self):
         controlnet = HunyuanDiT2DControlNetModel.from_pretrained(
@@ -199,7 +200,7 @@ class HunyuanDiTControlNetPipelineSlowTests(unittest.TestCase):
         pipe = HunyuanDiTControlNetPipeline.from_pretrained(
             "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -238,7 +239,7 @@ class HunyuanDiTControlNetPipelineSlowTests(unittest.TestCase):
         pipe = HunyuanDiTControlNetPipeline.from_pretrained(
             "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -277,7 +278,7 @@ class HunyuanDiTControlNetPipelineSlowTests(unittest.TestCase):
         pipe = HunyuanDiTControlNetPipeline.from_pretrained(
             "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -318,7 +319,7 @@ class HunyuanDiTControlNetPipelineSlowTests(unittest.TestCase):
         pipe = HunyuanDiTControlNetPipeline.from_pretrained(
             "Tencent-Hunyuan/HunyuanDiT-v1.1-Diffusers", controlnet=controlnet, torch_dtype=torch.float16
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
...
@@ -34,13 +34,14 @@ from diffusers import (
 )
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
     enable_full_determinism,
     is_torch_compile,
     load_image,
     load_numpy,
     require_accelerator,
     require_torch_2,
-    require_torch_gpu,
+    require_torch_accelerator,
     run_test_in_subprocess,
     slow,
     torch_device,
@@ -92,7 +93,7 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
         safety_checker=None,
         torch_dtype=torch.float16,
     )
-    pipe.to("cuda")
+    pipe.to(torch_device)
     pipe.set_progress_bar_config(disable=None)
     pipe.unet.to(memory_format=torch.channels_last)
@@ -334,12 +335,12 @@ class ControlNetXSPipelineFastTests(

 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class ControlNetXSPipelineSlowTests(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_canny(self):
         controlnet = ControlNetXSAdapter.from_pretrained(
@@ -348,7 +349,7 @@ class ControlNetXSPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
             "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -374,7 +375,7 @@ class ControlNetXSPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionControlNetXSPipeline.from_pretrained(
             "stabilityai/stable-diffusion-2-1-base", controlnet=controlnet, torch_dtype=torch.float16
         )
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
...
@@ -31,7 +31,14 @@ from diffusers import (
     UNet2DConditionModel,
 )
 from diffusers.utils.import_utils import is_xformers_available
-from diffusers.utils.testing_utils import enable_full_determinism, load_image, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    enable_full_determinism,
+    load_image,
+    require_torch_accelerator,
+    slow,
+    torch_device,
+)
 from diffusers.utils.torch_utils import randn_tensor

 from ...models.autoencoders.vae import (
@@ -192,7 +199,7 @@ class StableDiffusionXLControlNetXSPipelineFastTests(
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)

-    @require_torch_gpu
+    @require_torch_accelerator
     # Copied from test_controlnet_sdxl.py
     def test_stable_diffusion_xl_offloads(self):
         pipes = []
@@ -202,12 +209,12 @@ class StableDiffusionXLControlNetXSPipelineFastTests(

         components = self.get_dummy_components()
         sd_pipe = self.pipeline_class(**components)
-        sd_pipe.enable_model_cpu_offload()
+        sd_pipe.enable_model_cpu_offload(device=torch_device)
         pipes.append(sd_pipe)

         components = self.get_dummy_components()
         sd_pipe = self.pipeline_class(**components)
-        sd_pipe.enable_sequential_cpu_offload()
+        sd_pipe.enable_sequential_cpu_offload(device=torch_device)
         pipes.append(sd_pipe)

         image_slices = []
@@ -369,12 +376,12 @@ class StableDiffusionXLControlNetXSPipelineFastTests(

 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class StableDiffusionXLControlNetXSPipelineSlowTests(unittest.TestCase):
     def tearDown(self):
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_canny(self):
         controlnet = ControlNetXSAdapter.from_pretrained(
@@ -383,7 +390,7 @@ class StableDiffusionXLControlNetXSPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
             "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
         )
-        pipe.enable_sequential_cpu_offload()
+        pipe.enable_sequential_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -407,7 +414,7 @@ class StableDiffusionXLControlNetXSPipelineSlowTests(unittest.TestCase):
         pipe = StableDiffusionXLControlNetXSPipeline.from_pretrained(
             "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
         )
-        pipe.enable_sequential_cpu_offload()
+        pipe.enable_sequential_cpu_offload(device=torch_device)
         pipe.set_progress_bar_config(disable=None)

         generator = torch.Generator(device="cpu").manual_seed(0)
...
@@ -19,7 +19,7 @@ import numpy as np
 import torch

 from diffusers import DDIMPipeline, DDIMScheduler, UNet2DModel
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device

 from ..pipeline_params import UNCONDITIONAL_IMAGE_GENERATION_BATCH_PARAMS, UNCONDITIONAL_IMAGE_GENERATION_PARAMS
 from ..test_pipelines_common import PipelineTesterMixin
@@ -99,7 +99,7 @@ class DDIMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):

 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class DDIMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
...
@@ -19,7 +19,7 @@ import numpy as np
 import torch

 from diffusers import DDPMPipeline, DDPMScheduler, UNet2DModel
-from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import enable_full_determinism, require_torch_accelerator, slow, torch_device

 enable_full_determinism()
@@ -88,7 +88,7 @@ class DDPMPipelineFastTests(unittest.TestCase):

 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class DDPMPipelineIntegrationTests(unittest.TestCase):
     def test_inference_cifar10(self):
         model_id = "google/ddpm-cifar10-32"
...
@@ -24,10 +24,13 @@ from diffusers import (
 from diffusers.models.attention_processor import AttnAddedKVProcessor
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_reset_max_memory_allocated,
+    backend_reset_peak_memory_stats,
     load_numpy,
     require_accelerator,
     require_hf_hub_version_greater,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_transformers_version_greater,
     skip_mps,
     slow,
@@ -98,28 +101,28 @@ class IFPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class IFPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         # clean up the VRAM before each test
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_if_text_to_image(self):
         pipe = IFPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
         pipe.unet.set_attn_processor(AttnAddedKVProcessor())
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)

-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
+        backend_reset_max_memory_allocated(torch_device)
+        backend_empty_cache(torch_device)
+        backend_reset_peak_memory_stats(torch_device)

         generator = torch.Generator(device="cpu").manual_seed(0)
         output = pipe(
...
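`backend_empty_cache(torch_device)` replaces every hard-coded `torch.cuda.empty_cache()` in the setUp/tearDown pairs above. A hedged if/elif rendering of what the dispatch amounts to (the shipped helper routes through per-backend lookup tables instead):

```python
# Illustrative dispatch: release cached allocator blocks on whichever backend
# the suite resolved. Not the shipped implementation.
import torch


def backend_empty_cache(device: str) -> None:
    device = device.split(":")[0]  # "cuda:0" -> "cuda"
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "xpu" and hasattr(torch, "xpu"):
        torch.xpu.empty_cache()
    elif device == "mps" and hasattr(torch, "mps"):
        torch.mps.empty_cache()
    # "cpu": nothing to do -- there is no caching allocator to flush
```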
@@ -23,11 +23,14 @@ from diffusers import IFImg2ImgPipeline
 from diffusers.models.attention_processor import AttnAddedKVProcessor
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_reset_max_memory_allocated,
+    backend_reset_peak_memory_stats,
     floats_tensor,
     load_numpy,
     require_accelerator,
     require_hf_hub_version_greater,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_transformers_version_greater,
     skip_mps,
     slow,
@@ -109,19 +112,19 @@ class IFImg2ImgPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class IFImg2ImgPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         # clean up the VRAM before each test
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_if_img2img(self):
         pipe = IFImg2ImgPipeline.from_pretrained(
@@ -130,11 +133,11 @@ class IFImg2ImgPipelineSlowTests(unittest.TestCase):
             torch_dtype=torch.float16,
         )
         pipe.unet.set_attn_processor(AttnAddedKVProcessor())
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)

-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
+        backend_reset_max_memory_allocated(torch_device)
+        backend_empty_cache(torch_device)
+        backend_reset_peak_memory_stats(torch_device)

         image = floats_tensor((1, 3, 64, 64), rng=random.Random(0)).to(torch_device)
         generator = torch.Generator(device="cpu").manual_seed(0)
...
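The three resets before inference recur in each slow test above; grouping them shows the intent. `reset_memory_tracking` is a hypothetical name, not a diffusers helper:

```python
# Hypothetical grouping of the reset-then-measure pattern used by the slow
# tests. Clearing peak counters before inference makes the later max-memory
# assertion reflect this test alone, not earlier tests in the same process.
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_reset_max_memory_allocated,
    backend_reset_peak_memory_stats,
    torch_device,
)


def reset_memory_tracking():
    backend_reset_max_memory_allocated(torch_device)
    backend_empty_cache(torch_device)
    backend_reset_peak_memory_stats(torch_device)
```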
@@ -23,11 +23,15 @@ from diffusers import IFImg2ImgSuperResolutionPipeline
 from diffusers.models.attention_processor import AttnAddedKVProcessor
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
+    backend_empty_cache,
+    backend_max_memory_allocated,
+    backend_reset_max_memory_allocated,
+    backend_reset_peak_memory_stats,
     floats_tensor,
     load_numpy,
     require_accelerator,
     require_hf_hub_version_greater,
-    require_torch_gpu,
+    require_torch_accelerator,
     require_transformers_version_greater,
     skip_mps,
     slow,
@@ -106,19 +110,19 @@ class IFImg2ImgSuperResolutionPipelineFastTests(PipelineTesterMixin, IFPipelineTesterMixin, unittest.TestCase):
 @slow
-@require_torch_gpu
+@require_torch_accelerator
 class IFImg2ImgSuperResolutionPipelineSlowTests(unittest.TestCase):
     def setUp(self):
         # clean up the VRAM before each test
         super().setUp()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def tearDown(self):
         # clean up the VRAM after each test
         super().tearDown()
         gc.collect()
-        torch.cuda.empty_cache()
+        backend_empty_cache(torch_device)

     def test_if_img2img_superresolution(self):
         pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained(
@@ -127,11 +131,11 @@ class IFImg2ImgSuperResolutionPipelineSlowTests(unittest.TestCase):
             torch_dtype=torch.float16,
         )
         pipe.unet.set_attn_processor(AttnAddedKVProcessor())
-        pipe.enable_model_cpu_offload()
+        pipe.enable_model_cpu_offload(device=torch_device)

-        torch.cuda.reset_max_memory_allocated()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_peak_memory_stats()
+        backend_reset_max_memory_allocated(torch_device)
+        backend_empty_cache(torch_device)
+        backend_reset_peak_memory_stats(torch_device)

         generator = torch.Generator(device="cpu").manual_seed(0)
@@ -151,7 +155,8 @@ class IFImg2ImgSuperResolutionPipelineSlowTests(unittest.TestCase):
         assert image.shape == (256, 256, 3)

-        mem_bytes = torch.cuda.max_memory_allocated()
+        mem_bytes = backend_max_memory_allocated(torch_device)
         assert mem_bytes < 12 * 10**9

         expected_image = load_numpy(
...
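Finally, the 12 GB assertion reads its counter through `backend_max_memory_allocated`. A hedged sketch of such a reader; the `getattr` chain guards the XPU call because the counter only exists in newer PyTorch builds:

```python
# Illustrative dispatch for reading peak allocation on the active backend.
# The real helper lives in diffusers.utils.testing_utils.
import torch


def backend_max_memory_allocated(device: str) -> int:
    device = device.split(":")[0]  # normalize "cuda:0" -> "cuda"
    if device == "cuda":
        return torch.cuda.max_memory_allocated()
    if device == "xpu":
        fn = getattr(getattr(torch, "xpu", None), "max_memory_allocated", None)
        if fn is not None:
            return fn()
    return 0  # cpu/mps expose no comparable peak counter here
```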