Unverified Commit 41e4779d authored by Sayak Paul's avatar Sayak Paul Committed by GitHub
Browse files

[LoRA] fix: lora loading when using with a device_mapped model. (#9449)



* fix: lora loading when using with a device_mapped model.

* better attibutung

* empty
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>

* Apply suggestions from code review
Co-authored-by: default avatarMarc Sun <57196510+SunMarc@users.noreply.github.com>

* minors

* better error messages.

* fix-copies

* add: tests, docs.

* add hardware note.

* quality

* Update docs/source/en/training/distributed_inference.md
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>

* fixes

* skip properly.

* fixes

---------
Co-authored-by: default avatarBenjamin Bossan <BenjaminBossan@users.noreply.github.com>
Co-authored-by: default avatarMarc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: default avatarSteven Liu <59462357+stevhliu@users.noreply.github.com>
parent ff182ad6
...@@ -237,3 +237,5 @@ with torch.no_grad(): ...@@ -237,3 +237,5 @@ with torch.no_grad():
``` ```
By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs. By selectively loading and unloading the models you need at a given stage and sharding the largest models across multiple GPUs, it is possible to run inference with large models on consumer GPUs.
This workflow is also compatible with LoRAs via [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. However, only LoRAs without text encoder components are currently supported in this workflow.
...@@ -31,6 +31,7 @@ from ..utils import ( ...@@ -31,6 +31,7 @@ from ..utils import (
delete_adapter_layers, delete_adapter_layers,
deprecate, deprecate,
is_accelerate_available, is_accelerate_available,
is_accelerate_version,
is_peft_available, is_peft_available,
is_transformers_available, is_transformers_available,
logging, logging,
...@@ -214,9 +215,18 @@ class LoraBaseMixin: ...@@ -214,9 +215,18 @@ class LoraBaseMixin:
is_model_cpu_offload = False is_model_cpu_offload = False
is_sequential_cpu_offload = False is_sequential_cpu_offload = False
def model_has_device_map(model):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
return getattr(model, "hf_device_map", None) is not None
if _pipeline is not None and _pipeline.hf_device_map is None: if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items(): for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): if (
isinstance(component, nn.Module)
and hasattr(component, "_hf_hook")
and not model_has_device_map(component)
):
if not is_model_cpu_offload: if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload: if not is_sequential_cpu_offload:
......
...@@ -39,6 +39,7 @@ from ..utils import ( ...@@ -39,6 +39,7 @@ from ..utils import (
get_adapter_name, get_adapter_name,
get_peft_kwargs, get_peft_kwargs,
is_accelerate_available, is_accelerate_available,
is_accelerate_version,
is_peft_version, is_peft_version,
is_torch_version, is_torch_version,
logging, logging,
...@@ -398,9 +399,18 @@ class UNet2DConditionLoadersMixin: ...@@ -398,9 +399,18 @@ class UNet2DConditionLoadersMixin:
is_model_cpu_offload = False is_model_cpu_offload = False
is_sequential_cpu_offload = False is_sequential_cpu_offload = False
def model_has_device_map(model):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
return getattr(model, "hf_device_map", None) is not None
if _pipeline is not None and _pipeline.hf_device_map is None: if _pipeline is not None and _pipeline.hf_device_map is None:
for _, component in _pipeline.components.items(): for _, component in _pipeline.components.items():
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"): if (
isinstance(component, nn.Module)
and hasattr(component, "_hf_hook")
and not model_has_device_map(component)
):
if not is_model_cpu_offload: if not is_model_cpu_offload:
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload) is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
if not is_sequential_cpu_offload: if not is_sequential_cpu_offload:
......
...@@ -36,6 +36,7 @@ from ..utils import ( ...@@ -36,6 +36,7 @@ from ..utils import (
deprecate, deprecate,
get_class_from_dynamic_module, get_class_from_dynamic_module,
is_accelerate_available, is_accelerate_available,
is_accelerate_version,
is_peft_available, is_peft_available,
is_transformers_available, is_transformers_available,
logging, logging,
...@@ -947,3 +948,9 @@ def _get_ignore_patterns( ...@@ -947,3 +948,9 @@ def _get_ignore_patterns(
) )
return ignore_patterns return ignore_patterns
def model_has_device_map(model):
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
return False
return getattr(model, "hf_device_map", None) is not None
...@@ -85,6 +85,7 @@ from .pipeline_loading_utils import ( ...@@ -85,6 +85,7 @@ from .pipeline_loading_utils import (
_update_init_kwargs_with_connected_pipeline, _update_init_kwargs_with_connected_pipeline,
load_sub_model, load_sub_model,
maybe_raise_or_warn, maybe_raise_or_warn,
model_has_device_map,
variant_compatible_siblings, variant_compatible_siblings,
warn_deprecated_model_variant, warn_deprecated_model_variant,
) )
...@@ -406,6 +407,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -406,6 +407,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload) return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, accelerate.hooks.CpuOffload)
# device-mapped modules should not go through any device placements.
device_mapped_components = [
key for key, component in self.components.items() if model_has_device_map(component)
]
if device_mapped_components:
raise ValueError(
"The following pipeline components have been found to use a device map: "
f"{device_mapped_components}. This is incompatible with explicitly setting the device using `to()`."
)
# .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer # .to("cuda") would raise an error if the pipeline is sequentially offloaded, so we raise our own to make it clearer
pipeline_is_sequentially_offloaded = any( pipeline_is_sequentially_offloaded = any(
module_is_sequentially_offloaded(module) for _, module in self.components.items() module_is_sequentially_offloaded(module) for _, module in self.components.items()
...@@ -1002,6 +1013,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1002,6 +1013,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda". default to "cuda".
""" """
# device-mapped modules should not go through any device placements.
device_mapped_components = [
key for key, component in self.components.items() if model_has_device_map(component)
]
if device_mapped_components:
raise ValueError(
"The following pipeline components have been found to use a device map: "
f"{device_mapped_components}. This is incompatible with `enable_model_cpu_offload()`."
)
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1 is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped: if is_pipeline_device_mapped:
raise ValueError( raise ValueError(
...@@ -1104,6 +1125,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): ...@@ -1104,6 +1125,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
default to "cuda". default to "cuda".
""" """
# device-mapped modules should not go through any device placements.
device_mapped_components = [
key for key, component in self.components.items() if model_has_device_map(component)
]
if device_mapped_components:
raise ValueError(
"The following pipeline components have been found to use a device map: "
f"{device_mapped_components}. This is incompatible with `enable_sequential_cpu_offload()`."
)
if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
from accelerate import cpu_offload from accelerate import cpu_offload
else: else:
......
...@@ -506,9 +506,14 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -506,9 +506,14 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")} model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values())) self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))
@unittest.skip("Test currently not supported.")
def test_sequential_cpu_offload_forward_pass(self): def test_sequential_cpu_offload_forward_pass(self):
pass pass
@unittest.skip("Test currently not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@nightly @nightly
class AudioLDM2PipelineSlowTests(unittest.TestCase): class AudioLDM2PipelineSlowTests(unittest.TestCase):
......
...@@ -514,6 +514,18 @@ class StableDiffusionMultiControlNetPipelineFastTests( ...@@ -514,6 +514,18 @@ class StableDiffusionMultiControlNetPipelineFastTests(
assert image.shape == (4, 64, 64, 3) assert image.shape == (4, 64, 64, 3)
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
class StableDiffusionMultiControlNetOneModelPipelineFastTests( class StableDiffusionMultiControlNetOneModelPipelineFastTests(
IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
...@@ -697,6 +709,18 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests( ...@@ -697,6 +709,18 @@ class StableDiffusionMultiControlNetOneModelPipelineFastTests(
except NotImplementedError: except NotImplementedError:
pass pass
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
@slow @slow
@require_torch_gpu @require_torch_gpu
......
...@@ -389,6 +389,18 @@ class StableDiffusionMultiControlNetPipelineFastTests( ...@@ -389,6 +389,18 @@ class StableDiffusionMultiControlNetPipelineFastTests(
except NotImplementedError: except NotImplementedError:
pass pass
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
@slow @slow
@require_torch_gpu @require_torch_gpu
......
...@@ -441,6 +441,18 @@ class MultiControlNetInpaintPipelineFastTests( ...@@ -441,6 +441,18 @@ class MultiControlNetInpaintPipelineFastTests(
except NotImplementedError: except NotImplementedError:
pass pass
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
@slow @slow
@require_torch_gpu @require_torch_gpu
......
...@@ -683,6 +683,18 @@ class StableDiffusionXLMultiControlNetPipelineFastTests( ...@@ -683,6 +683,18 @@ class StableDiffusionXLMultiControlNetPipelineFastTests(
def test_save_load_optional_components(self): def test_save_load_optional_components(self):
return self._test_save_load_optional_components() return self._test_save_load_optional_components()
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
class StableDiffusionXLMultiControlNetOneModelPipelineFastTests( class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
...@@ -887,6 +899,18 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests( ...@@ -887,6 +899,18 @@ class StableDiffusionXLMultiControlNetOneModelPipelineFastTests(
self.assertTrue(np.abs(image_slice_without_neg_cond - image_slice_with_neg_cond).max() > 1e-2) self.assertTrue(np.abs(image_slice_without_neg_cond - image_slice_with_neg_cond).max() > 1e-2)
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
@slow @slow
@require_torch_gpu @require_torch_gpu
......
...@@ -8,9 +8,11 @@ from huggingface_hub import hf_hub_download ...@@ -8,9 +8,11 @@ from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel from transformers import AutoTokenizer, CLIPTextConfig, CLIPTextModel, CLIPTokenizer, T5EncoderModel
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
numpy_cosine_similarity_distance, numpy_cosine_similarity_distance,
require_big_gpu_with_torch_cuda, require_big_gpu_with_torch_cuda,
require_torch_multi_gpu,
slow, slow,
torch_device, torch_device,
) )
...@@ -282,3 +284,172 @@ class FluxPipelineSlowTests(unittest.TestCase): ...@@ -282,3 +284,172 @@ class FluxPipelineSlowTests(unittest.TestCase):
max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten()) max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
assert max_diff < 1e-4 assert max_diff < 1e-4
@require_torch_multi_gpu
@torch.no_grad()
def test_flux_component_sharding(self):
"""
internal note: test was run on `audace`.
"""
ckpt_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
prompt = "a photo of a cat with tiger-like look"
pipeline = FluxPipeline.from_pretrained(
ckpt_id,
transformer=None,
vae=None,
device_map="balanced",
max_memory={0: "16GB", 1: "16GB"},
torch_dtype=dtype,
)
prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
prompt=prompt, prompt_2=None, max_sequence_length=512
)
del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.tokenizer
del pipeline.tokenizer_2
del pipeline
gc.collect()
torch.cuda.empty_cache()
transformer = FluxTransformer2DModel.from_pretrained(
ckpt_id, subfolder="transformer", device_map="auto", max_memory={0: "16GB", 1: "16GB"}, torch_dtype=dtype
)
pipeline = FluxPipeline.from_pretrained(
ckpt_id,
text_encoder=None,
text_encoder_2=None,
tokenizer=None,
tokenizer_2=None,
vae=None,
transformer=transformer,
torch_dtype=dtype,
)
height, width = 768, 1360
# No need to wrap it up under `torch.no_grad()` as pipeline call method
# is already wrapped under that.
latents = pipeline(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
num_inference_steps=10,
guidance_scale=3.5,
height=height,
width=width,
output_type="latent",
generator=torch.manual_seed(0),
).images
latent_slice = latents[0, :3, :3].flatten().float().cpu().numpy()
expected_slice = np.array([-0.377, -0.3008, -0.5117, -0.252, 0.0615, -0.3477, -0.1309, -0.1914, 0.1533])
assert numpy_cosine_similarity_distance(latent_slice, expected_slice) < 1e-4
del pipeline.transformer
del pipeline
gc.collect()
torch.cuda.empty_cache()
vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype).to(torch_device)
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
image = vae.decode(latents, return_dict=False)[0]
image = image_processor.postprocess(image, output_type="np")
image_slice = image[0, :3, :3, -1].flatten()
expected_slice = np.array([0.127, 0.1113, 0.1055, 0.1172, 0.1172, 0.1074, 0.1191, 0.1191, 0.1152])
assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4
@require_torch_multi_gpu
@torch.no_grad()
def test_flux_component_sharding_with_lora(self):
"""
internal note: test was run on `audace`.
"""
ckpt_id = "black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
prompt = "jon snow eating pizza."
pipeline = FluxPipeline.from_pretrained(
ckpt_id,
transformer=None,
vae=None,
device_map="balanced",
max_memory={0: "16GB", 1: "16GB"},
torch_dtype=dtype,
)
prompt_embeds, pooled_prompt_embeds, _ = pipeline.encode_prompt(
prompt=prompt, prompt_2=None, max_sequence_length=512
)
del pipeline.text_encoder
del pipeline.text_encoder_2
del pipeline.tokenizer
del pipeline.tokenizer_2
del pipeline
gc.collect()
torch.cuda.empty_cache()
transformer = FluxTransformer2DModel.from_pretrained(
ckpt_id, subfolder="transformer", device_map="auto", max_memory={0: "16GB", 1: "16GB"}, torch_dtype=dtype
)
pipeline = FluxPipeline.from_pretrained(
ckpt_id,
text_encoder=None,
text_encoder_2=None,
tokenizer=None,
tokenizer_2=None,
vae=None,
transformer=transformer,
torch_dtype=dtype,
)
pipeline.load_lora_weights("TheLastBen/Jon_Snow_Flux_LoRA", weight_name="jon_snow.safetensors")
height, width = 768, 1360
# No need to wrap it up under `torch.no_grad()` as pipeline call method
# is already wrapped under that.
latents = pipeline(
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
num_inference_steps=10,
guidance_scale=3.5,
height=height,
width=width,
output_type="latent",
generator=torch.manual_seed(0),
).images
latent_slice = latents[0, :3, :3].flatten().float().cpu().numpy()
expected_slice = np.array([-0.6523, -0.4961, -0.9141, -0.5, -0.2129, -0.6914, -0.375, -0.5664, -0.1699])
assert numpy_cosine_similarity_distance(latent_slice, expected_slice) < 1e-4
del pipeline.transformer
del pipeline
gc.collect()
torch.cuda.empty_cache()
vae = AutoencoderKL.from_pretrained(ckpt_id, subfolder="vae", torch_dtype=dtype).to(torch_device)
vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
latents = FluxPipeline._unpack_latents(latents, height, width, vae_scale_factor)
latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
image = vae.decode(latents, return_dict=False)[0]
image = image_processor.postprocess(image, output_type="np")
image_slice = image[0, :3, :3, -1].flatten()
expected_slice = np.array([0.1211, 0.1094, 0.1035, 0.1094, 0.1113, 0.1074, 0.1133, 0.1133, 0.1094])
assert numpy_cosine_similarity_distance(image_slice, expected_slice) < 1e-4
...@@ -139,6 +139,18 @@ class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase) ...@@ -139,6 +139,18 @@ class KandinskyPipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCase)
def test_dict_tuple_outputs_equivalent(self): def test_dict_tuple_outputs_equivalent(self):
super().test_dict_tuple_outputs_equivalent(expected_max_difference=5e-4) super().test_dict_tuple_outputs_equivalent(expected_max_difference=5e-4)
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase): class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyImg2ImgCombinedPipeline pipeline_class = KandinskyImg2ImgCombinedPipeline
...@@ -248,6 +260,18 @@ class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.Te ...@@ -248,6 +260,18 @@ class KandinskyPipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.Te
def test_save_load_optional_components(self): def test_save_load_optional_components(self):
super().test_save_load_optional_components(expected_max_difference=5e-4) super().test_save_load_optional_components(expected_max_difference=5e-4)
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase): class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyInpaintCombinedPipeline pipeline_class = KandinskyInpaintCombinedPipeline
...@@ -363,3 +387,15 @@ class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.Te ...@@ -363,3 +387,15 @@ class KandinskyPipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.Te
def test_save_load_local(self): def test_save_load_local(self):
super().test_save_load_local(expected_max_difference=5e-3) super().test_save_load_local(expected_max_difference=5e-3)
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
...@@ -159,6 +159,18 @@ class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCa ...@@ -159,6 +159,18 @@ class KandinskyV22PipelineCombinedFastTests(PipelineTesterMixin, unittest.TestCa
def test_callback_cfg(self): def test_callback_cfg(self):
pass pass
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase): class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyV22Img2ImgCombinedPipeline pipeline_class = KandinskyV22Img2ImgCombinedPipeline
...@@ -281,6 +293,18 @@ class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest ...@@ -281,6 +293,18 @@ class KandinskyV22PipelineImg2ImgCombinedFastTests(PipelineTesterMixin, unittest
def test_callback_cfg(self): def test_callback_cfg(self):
pass pass
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase): class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = KandinskyV22InpaintCombinedPipeline pipeline_class = KandinskyV22InpaintCombinedPipeline
...@@ -404,3 +428,15 @@ class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest ...@@ -404,3 +428,15 @@ class KandinskyV22PipelineInpaintCombinedFastTests(PipelineTesterMixin, unittest
def test_callback_cfg(self): def test_callback_cfg(self):
pass pass
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
...@@ -404,6 +404,10 @@ class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): ...@@ -404,6 +404,10 @@ class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")} model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values())) self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))
@unittest.skip("Test currently not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@nightly @nightly
@require_torch_gpu @require_torch_gpu
......
...@@ -279,3 +279,15 @@ class StableCascadeCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestC ...@@ -279,3 +279,15 @@ class StableCascadeCombinedPipelineFastTests(PipelineTesterMixin, unittest.TestC
) )
assert np.abs(output_prompt.images - output_prompt_embeds.images).max() < 1e-5 assert np.abs(output_prompt.images - output_prompt_embeds.images).max() < 1e-5
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
...@@ -593,6 +593,18 @@ class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterM ...@@ -593,6 +593,18 @@ class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterM
if test_mean_pixel_difference: if test_mean_pixel_difference:
assert_mean_pixel_difference(output_batch[0][0], output[0][0]) assert_mean_pixel_difference(output_batch[0][0], output[0][0])
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
@slow @slow
@require_torch_gpu @require_torch_gpu
......
...@@ -642,9 +642,6 @@ class StableDiffusionXLMultiAdapterPipelineFastTests( ...@@ -642,9 +642,6 @@ class StableDiffusionXLMultiAdapterPipelineFastTests(
assert image.shape == (1, 64, 64, 3) assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448]) expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448])
debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()]
print(",".join(debug))
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_adapter_sdxl_lcm_custom_timesteps(self): def test_adapter_sdxl_lcm_custom_timesteps(self):
...@@ -667,7 +664,16 @@ class StableDiffusionXLMultiAdapterPipelineFastTests( ...@@ -667,7 +664,16 @@ class StableDiffusionXLMultiAdapterPipelineFastTests(
assert image.shape == (1, 64, 64, 3) assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448]) expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448])
debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()]
print(",".join(debug))
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
...@@ -184,6 +184,18 @@ class StableUnCLIPPipelineFastTests( ...@@ -184,6 +184,18 @@ class StableUnCLIPPipelineFastTests(
def test_inference_batch_single_identical(self): def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(expected_max_diff=1e-3) self._test_inference_batch_single_identical(expected_max_diff=1e-3)
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
@nightly @nightly
@require_torch_gpu @require_torch_gpu
......
...@@ -205,6 +205,18 @@ class StableUnCLIPImg2ImgPipelineFastTests( ...@@ -205,6 +205,18 @@ class StableUnCLIPImg2ImgPipelineFastTests(
def test_xformers_attention_forwardGenerator_pass(self): def test_xformers_attention_forwardGenerator_pass(self):
self._test_xformers_attention_forwardGenerator_pass(test_max_difference=False) self._test_xformers_attention_forwardGenerator_pass(test_max_difference=False)
@unittest.skip("Test not supported.")
def test_calling_mco_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_to_raises_error_device_mapped_components(self):
pass
@unittest.skip("Test not supported.")
def test_calling_sco_raises_error_device_mapped_components(self):
pass
@nightly @nightly
@require_torch_gpu @require_torch_gpu
......
...@@ -41,8 +41,11 @@ from diffusers.utils import logging ...@@ -41,8 +41,11 @@ from diffusers.utils import logging
from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
from diffusers.utils.testing_utils import ( from diffusers.utils.testing_utils import (
CaptureLogger, CaptureLogger,
nightly,
require_torch, require_torch,
require_torch_multi_gpu,
skip_mps, skip_mps,
slow,
torch_device, torch_device,
) )
...@@ -59,6 +62,10 @@ from ..models.unets.test_models_unet_2d_condition import ( ...@@ -59,6 +62,10 @@ from ..models.unets.test_models_unet_2d_condition import (
from ..others.test_utils import TOKEN, USER, is_staging_test from ..others.test_utils import TOKEN, USER, is_staging_test
if is_accelerate_available():
from accelerate.utils import compute_module_sizes
def to_np(tensor): def to_np(tensor):
if isinstance(tensor, torch.Tensor): if isinstance(tensor, torch.Tensor):
tensor = tensor.detach().cpu().numpy() tensor = tensor.detach().cpu().numpy()
...@@ -1908,6 +1915,78 @@ class PipelineTesterMixin: ...@@ -1908,6 +1915,78 @@ class PipelineTesterMixin:
) )
) )
@require_torch_multi_gpu
@slow
@nightly
def test_calling_to_raises_error_device_mapped_components(self, safe_serialization=True):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
max_model_size = max(
compute_module_sizes(module)[""]
for _, module in pipe.components.items()
if isinstance(module, torch.nn.Module)
)
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=safe_serialization)
max_memory = {0: max_model_size, 1: max_model_size}
loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map="balanced", max_memory=max_memory)
with self.assertRaises(ValueError) as err_context:
loaded_pipe.to(torch_device)
self.assertTrue(
"The following pipeline components have been found" in str(err_context.exception)
and "This is incompatible with explicitly setting the device using `to()`" in str(err_context.exception)
)
@require_torch_multi_gpu
@slow
@nightly
def test_calling_mco_raises_error_device_mapped_components(self, safe_serialization=True):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
max_model_size = max(
compute_module_sizes(module)[""]
for _, module in pipe.components.items()
if isinstance(module, torch.nn.Module)
)
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=safe_serialization)
max_memory = {0: max_model_size, 1: max_model_size}
loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map="balanced", max_memory=max_memory)
with self.assertRaises(ValueError) as err_context:
loaded_pipe.enable_model_cpu_offload()
self.assertTrue(
"The following pipeline components have been found" in str(err_context.exception)
and "This is incompatible with `enable_model_cpu_offload()`" in str(err_context.exception)
)
@require_torch_multi_gpu
@slow
@nightly
def test_calling_sco_raises_error_device_mapped_components(self, safe_serialization=True):
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
max_model_size = max(
compute_module_sizes(module)[""]
for _, module in pipe.components.items()
if isinstance(module, torch.nn.Module)
)
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=safe_serialization)
max_memory = {0: max_model_size, 1: max_model_size}
loaded_pipe = self.pipeline_class.from_pretrained(tmpdir, device_map="balanced", max_memory=max_memory)
with self.assertRaises(ValueError) as err_context:
loaded_pipe.enable_sequential_cpu_offload()
self.assertTrue(
"The following pipeline components have been found" in str(err_context.exception)
and "This is incompatible with `enable_sequential_cpu_offload()`" in str(err_context.exception)
)
@is_staging_test @is_staging_test
class PipelinePushToHubTester(unittest.TestCase): class PipelinePushToHubTester(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment