Unverified Commit 7956c36a authored by YiYi Xu, committed by GitHub

add a `from_pipe` method to `DiffusionPipeline` (#7241)



* add from_pipe



---------
Co-authored-by: yiyixuxu <yixu310@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
parent 5266ab79
...@@ -179,6 +179,210 @@ stable_diffusion_img2img = StableDiffusionImg2ImgPipeline(
)
```
### Switch loaded pipelines

Many diffusers pipelines use the same pre-trained models as [`StableDiffusionPipeline`] and [`StableDiffusionXLPipeline`], but implement specific features to help you achieve better generation results. This guide shows you how to use the `from_pipe` API to create multiple pipelines without increasing memory usage, so you can easily switch between pipelines to use different features.

Let's take an example where we first create a [`StableDiffusionPipeline`] and then reuse the already loaded model components to create a [`StableDiffusionSAGPipeline`] to enhance generation quality. We will generate an image of a bear eating pizza using Stable Diffusion with IP-Adapter.
```python
import torch
from diffusers import DiffusionPipeline, StableDiffusionSAGPipeline
from diffusers.utils import load_image

base_repo = "SG161222/Realistic_Vision_V6.0_B1_noVAE"
num_inference_steps = 50

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png")
prompt = "bear eats pizza"
negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality"

pipe_sd = DiffusionPipeline.from_pretrained(base_repo, torch_dtype=torch.float16)
pipe_sd.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe_sd.set_ip_adapter_scale(0.6)
pipe_sd.to("cuda")

generator = torch.Generator(device="cpu").manual_seed(33)
out_sd = pipe_sd(
    prompt=prompt,
    negative_prompt=negative_prompt,
    ip_adapter_image=image,
    num_inference_steps=num_inference_steps,
    generator=generator,
).images[0]
```
Let's take a look at the image and print out the memory used:
<div class="flex justify-center">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_sd_0.png"/>
</div>
```python
def bytes_to_giga_bytes(bytes):
    return bytes / 1024 / 1024 / 1024


print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB")
```
```bash
Max memory allocated: 4.406213283538818 GB
```
Now, we can use `from_pipe` to switch to the SAG pipeline.
```python
pipe_sag = StableDiffusionSAGPipeline.from_pipe(pipe_sd)
```
It already has the IP-Adapter loaded, so you can pass the same bear image as `ip_adapter_image`:
```python
generator = torch.Generator(device="cpu").manual_seed(33)
out_sag = pipe_sag(
    prompt=prompt,
    negative_prompt=negative_prompt,
    ip_adapter_image=image,
    num_inference_steps=num_inference_steps,
    generator=generator,
    guidance_scale=1.0,
    sag_scale=0.75,
).images[0]
```
You can see a pretty nice improvement in the output:
<div class="flex justify-center">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_sag_1.png"/>
</div>
Now we have both `StableDiffusionPipeline` and `StableDiffusionSAGPipeline` sharing the same loaded model components; you can use them interchangeably without additional memory cost.
```python
print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB")
```
```bash
Max memory allocated: 4.406213283538818 GB
```
Let's unload the IP-Adapter from the SAG pipeline. Keep in mind that methods like `load_ip_adapter` and `unload_ip_adapter` modify the state of the underlying model components, so calling them on one pipeline affects all other pipelines that share the same components.
```python
pipe_sag.unload_ip_adapter()
```
If you try to use the Stable Diffusion pipeline with the IP-Adapter again, it will fail:
```python
generator = torch.Generator(device="cpu").manual_seed(33)
out_sd = pipe_sd(
    prompt=prompt,
    negative_prompt=negative_prompt,
    ip_adapter_image=image,
    num_inference_steps=num_inference_steps,
    generator=generator,
).images[0]
```
```bash
AttributeError: 'NoneType' object has no attribute 'image_projection_layers'
```
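Since the components are shared, reloading the IP-Adapter on either pipeline makes it available to both again. As a minimal sketch, reusing the same checkpoint from earlier in this guide:

```python
# reloading the IP-Adapter modifies the shared UNet, so both
# pipe_sd and pipe_sag can use it again afterwards
pipe_sd.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe_sd.set_ip_adapter_scale(0.6)
```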
Please note that pipeline methods may not function properly on a new pipeline created with `from_pipe`. For instance, the `enable_model_cpu_offload` method installs hooks on the model components based on a unique offloading sequence for each pipeline. If the models are executed in a different order in the new pipeline, the CPU offloading may not work correctly.
To ensure proper functionality, we recommend re-applying such pipeline methods on the new pipeline created with `from_pipe`.
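For example, here is a minimal sketch of re-applying CPU offloading after `from_pipe` (reusing the `pipe_sd` pipeline from this guide):

```python
# re-install the offload hooks on the new pipeline so they follow
# its own execution order rather than the original pipeline's
pipe_sag = StableDiffusionSAGPipeline.from_pipe(pipe_sd)
pipe_sag.enable_model_cpu_offload()
```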
You can also add or remove model components when you create new pipelines. Let's now create an AnimateDiff pipeline with an additional `MotionAdapter` module:
```python
from diffusers import AnimateDiffPipeline, MotionAdapter, DDIMScheduler
from diffusers.utils import export_to_gif
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe_animate = AnimateDiffPipeline.from_pipe(pipe_sd, motion_adapter=adapter)
pipe_animate.scheduler = DDIMScheduler.from_config(pipe_animate.scheduler.config, beta_schedule="linear")
# load ip_adapter again and load lora weights
pipe_animate.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipe_animate.load_lora_weights("guoyww/animatediff-motion-lora-zoom-out", adapter_name="zoom-out")
pipe_animate.to("cuda")
generator = torch.Generator(device="cpu").manual_seed(33)
pipe_animate.set_adapters("zoom-out", adapter_weights=0.75)
out = pipe_animate(
    prompt=prompt,
    num_frames=16,
    num_inference_steps=num_inference_steps,
    ip_adapter_image=image,
    generator=generator,
).frames[0]
export_to_gif(out, "out_animate.gif")
```
<div class="flex justify-center">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_animate_3.gif"/>
</div>
When creating multiple pipelines with the `from_pipe` method, the total memory requirement is determined by the pipeline with the highest memory usage, regardless of how many pipelines you create.
For example, if we create three pipelines, `StableDiffusionPipeline`, `StableDiffusionSAGPipeline`, and `AnimateDiffPipeline`, and `AnimateDiffPipeline` has the highest memory requirement, then the total memory usage is based on `AnimateDiffPipeline` alone.
Creating additional pipelines therefore does not add to the total memory requirement, and each pipeline can be used interchangeably without any additional memory overhead.
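As a quick sanity check, you can reuse the `bytes_to_giga_bytes` helper defined earlier to print the peak memory after creating all three pipelines; it reflects the largest pipeline rather than the sum of all of them:

```python
# peak memory is driven by the largest pipeline (here AnimateDiff),
# not by the number of pipelines sharing the same components
print(f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB")
```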
Did you know that you can use `from_pipe` with a community pipeline? Let me show you an example using the long prompt weighting community pipeline, which supports long prompts, long negative prompts, and prompt weighting!
```python
pipe_lpw = DiffusionPipeline.from_pipe(
    pipe_sd,
    custom_pipeline="lpw_stable_diffusion",
).to("cuda")

prompt = "best_quality (1girl:1.3) bow bride brown_hair closed_mouth frilled_bow frilled_hair_tubes frills (full_body:1.3) fox_ear hair_bow hair_tubes happy hood japanese_clothes kimono long_sleeves red_bow smile solo tabi uchikake white_kimono wide_sleeves cherry_blossoms"
neg_prompt = "lowres, bad_anatomy, error_body, error_hair, error_arm, error_hands, bad_hands, error_fingers, bad_fingers, missing_fingers, error_legs, bad_legs, multiple_legs, missing_legs, error_lighting, error_shadow, error_reflection, text, error, extra_digit, fewer_digits, cropped, worst_quality, low_quality, normal_quality, jpeg_artifacts, signature, watermark, username, blurry"
generator = torch.Generator(device="cpu").manual_seed(33)
out_lpw = pipe_lpw.text2img(
    prompt,
    negative_prompt=neg_prompt,
    width=512,
    height=512,
    max_embeddings_multiples=3,
    num_inference_steps=num_inference_steps,
    generator=generator,
).images[0]
```
<div class="flex justify-center">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_lpw_4.png"/>
</div>
Let's run the [`StableDiffusionPipeline`] with the same inputs to compare: the result from the long prompt weighting pipeline is more aligned with the text prompt.
```python
generator = torch.Generator(device="cpu").manual_seed(33)
out_sd = pipe_sd(
    prompt=prompt,
    negative_prompt=neg_prompt,
    generator=generator,
    num_inference_steps=num_inference_steps,
).images[0]
out_sd
```
<div class="flex justify-center">
<img class="rounded-xl" src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/from_pipe_out_sd_5.png"/>
</div>
You can easily switch between different pipelines using the `from_pipe` method, similar to turning a feature on and off. To switch between tasks, use `from_pipe` with `AutoPipeline`, which automatically identifies the pipeline class based on the task. You can find more information about this feature in the [AutoPipeline guide](https://huggingface.co/docs/diffusers/tutorials/autopipeline).
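For instance, here is a minimal sketch of switching the same loaded components over to an image-to-image task with `AutoPipelineForImage2Image` (reusing the `pipe_sd` pipeline from this guide):

```python
from diffusers import AutoPipelineForImage2Image

# AutoPipeline resolves the image-to-image pipeline class that matches
# the loaded checkpoint and reuses pipe_sd's components
pipe_img2img = AutoPipelineForImage2Image.from_pipe(pipe_sd)
```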
## Checkpoint variants
A checkpoint variant is usually a checkpoint whose weights are:
...
...@@ -439,7 +439,9 @@ class StableDiffusionLongPromptWeightingPipeline(
             Model that extracts features from generated images to be used as inputs for the `safety_checker`.
     """

+    model_cpu_offload_seq = "text_encoder->unet->vae"
     _optional_components = ["safety_checker", "feature_extractor"]
+    _exclude_from_cpu_offload = ["safety_checker"]

     def __init__(
         self,
...
...@@ -17,7 +17,7 @@ import torch
 import torch.nn as nn
 import torch.utils.checkpoint

-from ...configuration_utils import ConfigMixin, register_to_config
+from ...configuration_utils import ConfigMixin, FrozenDict, register_to_config
 from ...loaders import UNet2DConditionLoadersMixin
 from ...utils import logging
 from ..attention_processor import (
...@@ -393,8 +393,11 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     ):
         has_motion_adapter = motion_adapter is not None

+        if has_motion_adapter:
+            motion_adapter.to(device=unet.device)
+
         # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
-        config = unet.config
+        config = dict(unet.config)
         config["_class_name"] = cls.__name__

         down_blocks = []
...@@ -427,6 +430,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         if not config.get("num_attention_heads"):
             config["num_attention_heads"] = config["attention_head_dim"]

+        config = FrozenDict(config)
         model = cls.from_config(config)

         if not load_weights:
...
...@@ -131,7 +131,7 @@ class AnimateDiffPipeline(
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
         tokenizer: CLIPTokenizer,
-        unet: UNet2DConditionModel,
+        unet: Union[UNet2DConditionModel, UNetMotionModel],
         motion_adapter: MotionAdapter,
         scheduler: Union[
             DDIMScheduler,
...
...@@ -292,6 +292,39 @@ def get_class_obj_and_candidates(
     return class_obj, class_candidates


+def _get_custom_pipeline_class(
+    custom_pipeline,
+    repo_id=None,
+    hub_revision=None,
+    class_name=None,
+    cache_dir=None,
+    revision=None,
+):
+    if custom_pipeline.endswith(".py"):
+        path = Path(custom_pipeline)
+        # decompose into folder & file
+        file_name = path.name
+        custom_pipeline = path.parent.absolute()
+    elif repo_id is not None:
+        file_name = f"{custom_pipeline}.py"
+        custom_pipeline = repo_id
+    else:
+        file_name = CUSTOM_PIPELINE_FILE_NAME
+
+    if repo_id is not None and hub_revision is not None:
+        # if we load the pipeline code from the Hub
+        # make sure to overwrite the `revision`
+        revision = hub_revision
+
+    return get_class_from_dynamic_module(
+        custom_pipeline,
+        module_file=file_name,
+        class_name=class_name,
+        cache_dir=cache_dir,
+        revision=revision,
+    )
+
+
 def _get_pipeline_class(
     class_obj,
     config=None,
...@@ -304,25 +337,10 @@ def _get_pipeline_class(
     revision=None,
 ):
     if custom_pipeline is not None:
-        if custom_pipeline.endswith(".py"):
-            path = Path(custom_pipeline)
-            # decompose into folder & file
-            file_name = path.name
-            custom_pipeline = path.parent.absolute()
-        elif repo_id is not None:
-            file_name = f"{custom_pipeline}.py"
-            custom_pipeline = repo_id
-        else:
-            file_name = CUSTOM_PIPELINE_FILE_NAME
-
-        if repo_id is not None and hub_revision is not None:
-            # if we load the pipeline code from the Hub
-            # make sure to overwrite the `revision`
-            revision = hub_revision
-
-        return get_class_from_dynamic_module(
+        return _get_custom_pipeline_class(
             custom_pipeline,
-            module_file=file_name,
+            repo_id=repo_id,
+            hub_revision=hub_revision,
             class_name=class_name,
             cache_dir=cache_dir,
             revision=revision,
...
...@@ -21,7 +21,7 @@ import re
 import sys
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin

 import numpy as np
 import PIL.Image
...@@ -43,7 +43,7 @@ from .. import __version__
 from ..configuration_utils import ConfigMixin
 from ..models import AutoencoderKL
 from ..models.attention_processor import FusedAttnProcessor2_0
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
+from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
 from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from ..utils import (
     CONFIG_NAME,
...@@ -72,6 +72,7 @@ from .pipeline_loading_utils import (
     CUSTOM_PIPELINE_FILE_NAME,
     LOADABLE_CLASSES,
     _fetch_class_library_tuple,
+    _get_custom_pipeline_class,
     _get_pipeline_class,
     _unwrap_model,
     is_safetensors_compatible,
...@@ -1475,6 +1476,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         return expected_modules, optional_parameters

+    @classmethod
+    def _get_signature_types(cls):
+        signature_types = {}
+        for k, v in inspect.signature(cls.__init__).parameters.items():
+            if inspect.isclass(v.annotation):
+                signature_types[k] = (v.annotation,)
+            elif get_origin(v.annotation) == Union:
+                signature_types[k] = get_args(v.annotation)
+            else:
+                logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
+        return signature_types
+
     @property
     def components(self) -> Dict[str, Any]:
         r"""
...@@ -1653,6 +1666,128 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         for module in modules:
             module.set_attention_slice(slice_size)

+    @classmethod
+    def from_pipe(cls, pipeline, **kwargs):
+        r"""
+        Create a new pipeline from a given pipeline. This method is useful to create a new pipeline from the existing
+        pipeline components without reallocating additional memory.
+
+        Arguments:
+            pipeline (`DiffusionPipeline`):
+                The pipeline from which to create a new pipeline.
+
+        Returns:
+            `DiffusionPipeline`:
+                A new pipeline with the same weights and configurations as `pipeline`.
+
+        Examples:
+
+        ```py
+        >>> from diffusers import StableDiffusionPipeline, StableDiffusionSAGPipeline
+
+        >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
+        >>> new_pipe = StableDiffusionSAGPipeline.from_pipe(pipe)
+        ```
+        """
+        original_config = dict(pipeline.config)
+        torch_dtype = kwargs.pop("torch_dtype", None)
+
+        # derive the pipeline class to instantiate
+        custom_pipeline = kwargs.pop("custom_pipeline", None)
+        custom_revision = kwargs.pop("custom_revision", None)
+
+        if custom_pipeline is not None:
+            pipeline_class = _get_custom_pipeline_class(custom_pipeline, revision=custom_revision)
+        else:
+            pipeline_class = cls
+
+        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
+        # true_optional_modules are optional components with a default value in the signature, so it is ok not to pass
+        # them to `__init__`, e.g. `image_encoder` for StableDiffusionPipeline
+        parameters = inspect.signature(cls.__init__).parameters
+        true_optional_modules = set(
+            {k for k, v in parameters.items() if v.default != inspect._empty and k in expected_modules}
+        )
+
+        # get the class of each component based on its type hint
+        # e.g. {"unet": UNet2DConditionModel, "text_encoder": CLIPTextModel}
+        component_types = pipeline_class._get_signature_types()
+
+        pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
+        # allow users to pass modules in `kwargs` to override the original pipeline's components
+        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+
+        original_class_obj = {}
+        for name, component in pipeline.components.items():
+            if name in expected_modules and name not in passed_class_obj:
+                # for model components, we will not switch over if the class does not match the type hint in the new
+                # pipeline's signature
+                if (
+                    not isinstance(component, ModelMixin)
+                    or type(component) in component_types[name]
+                    or (component is None and name in cls._optional_components)
+                ):
+                    original_class_obj[name] = component
+                else:
+                    logger.warn(
+                        f"component {name} is not switched over to new pipeline because type does not match the expected."
+                        f" {name} is {type(component)} while the new pipeline expect {component_types[name]}."
+                        f" please pass the component of the correct type to the new pipeline. `from_pipe(..., {name}={name})`"
+                    )
+
+        # allow users to pass optional kwargs to override the original pipeline's config attributes
+        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+        original_pipe_kwargs = {
+            k: original_config[k]
+            for k in original_config.keys()
+            if k in optional_kwargs and k not in passed_pipe_kwargs
+        }
+
+        # config attributes that were not expected by the pipeline are stored as its private attributes
+        # (i.e. when the original pipeline was also instantiated with `from_pipe` from another pipeline that has this config)
+        # in this case, we will pass them as optional arguments if they can be accepted by the new pipeline
+        additional_pipe_kwargs = [
+            k[1:]
+            for k in original_config.keys()
+            if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
+        ]
+        for k in additional_pipe_kwargs:
+            original_pipe_kwargs[k] = original_config.pop(f"_{k}")
+
+        pipeline_kwargs = {
+            **passed_class_obj,
+            **original_class_obj,
+            **passed_pipe_kwargs,
+            **original_pipe_kwargs,
+            **kwargs,
+        }
+
+        # store unused config as a private attribute in the new pipeline
+        unused_original_config = {
+            f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs
+        }
+
+        missing_modules = (
+            set(expected_modules)
+            - set(pipeline._optional_components)
+            - set(pipeline_kwargs.keys())
+            - set(true_optional_modules)
+        )
+
+        if len(missing_modules) > 0:
+            raise ValueError(
+                f"Pipeline {pipeline_class} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed"
+            )
+
+        new_pipeline = pipeline_class(**pipeline_kwargs)
+        if pretrained_model_name_or_path is not None:
+            new_pipeline.register_to_config(_name_or_path=pretrained_model_name_or_path)
+        new_pipeline.register_to_config(**unused_original_config)
+
+        if torch_dtype is not None:
+            new_pipeline.to(dtype=torch_dtype)
+
+        return new_pipeline
+

 class StableDiffusionMixin:
     r"""
...
...@@ -902,6 +902,7 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionM
         if attn_res is None:
             attn_res = int(np.ceil(width / 32)), int(np.ceil(height / 32))
         self.attention_store = AttentionStore(attn_res)
+        original_attn_proc = self.unet.attn_processors
         self.register_attention_control()

         # default config for step size from original repo
...@@ -1016,6 +1017,8 @@ class StableDiffusionAttendAndExcitePipeline(DiffusionPipeline, StableDiffusionM
         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

         self.maybe_free_model_hooks()
+        # make sure to set the original attention processors back
+        self.unet.set_attn_processor(original_attn_proc)

         if not return_dict:
             return (image, has_nsfw_concept)
...
...@@ -750,6 +750,7 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
         )

         # 7. Denoising loop
+        original_attn_proc = self.unet.attn_processors
         store_processor = CrossAttnStoreProcessor()
         self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
...@@ -848,6 +849,8 @@ class StableDiffusionSAGPipeline(DiffusionPipeline, StableDiffusionMixin, Textua
         image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)

         self.maybe_free_model_hooks()
+        # make sure to set the original attention processors back
+        self.unet.set_attn_processor(original_attn_proc)

         if not return_dict:
             return (image, has_nsfw_concept)
...
...@@ -329,13 +329,6 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn ...@@ -329,13 +329,6 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn
safety_checker=safety_checker, safety_checker=safety_checker,
feature_extractor=feature_extractor, feature_extractor=feature_extractor,
) )
processor = (
CrossFrameAttnProcessor2_0(batch_size=2)
if hasattr(F, "scaled_dot_product_attention")
else CrossFrameAttnProcessor(batch_size=2)
)
self.unet.set_attn_processor(processor)
if safety_checker is None and requires_safety_checker: if safety_checker is None and requires_safety_checker:
logger.warning( logger.warning(
f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
...@@ -616,6 +609,15 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn ...@@ -616,6 +609,15 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn
assert num_videos_per_prompt == 1 assert num_videos_per_prompt == 1
# set the processor
original_attn_proc = self.unet.attn_processors
processor = (
CrossFrameAttnProcessor2_0(batch_size=2)
if hasattr(F, "scaled_dot_product_attention")
else CrossFrameAttnProcessor(batch_size=2)
)
self.unet.set_attn_processor(processor)
if isinstance(prompt, str): if isinstance(prompt, str):
prompt = [prompt] prompt = [prompt]
if isinstance(negative_prompt, str): if isinstance(negative_prompt, str):
...@@ -739,6 +741,8 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn ...@@ -739,6 +741,8 @@ class TextToVideoZeroPipeline(DiffusionPipeline, StableDiffusionMixin, TextualIn
# Offload all models # Offload all models
self.maybe_free_model_hooks() self.maybe_free_model_hooks()
# make sure to set the original attention processors back
self.unet.set_attn_processor(original_attn_proc)
if not return_dict: if not return_dict:
return (image, has_nsfw_concept) return (image, has_nsfw_concept)
......
...@@ -411,14 +411,6 @@ class TextToVideoZeroSDXLPipeline(
         else:
             self.watermark = None

-        processor = (
-            CrossFrameAttnProcessor2_0(batch_size=2)
-            if hasattr(F, "scaled_dot_product_attention")
-            else CrossFrameAttnProcessor(batch_size=2)
-        )
-        self.unet.set_attn_processor(processor)
-
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
     def prepare_extra_step_kwargs(self, generator, eta):
         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
...@@ -1084,6 +1076,15 @@ class TextToVideoZeroSDXLPipeline(
         assert num_videos_per_prompt == 1

+        # set the processor
+        original_attn_proc = self.unet.attn_processors
+        processor = (
+            CrossFrameAttnProcessor2_0(batch_size=2)
+            if hasattr(F, "scaled_dot_product_attention")
+            else CrossFrameAttnProcessor(batch_size=2)
+        )
+        self.unet.set_attn_processor(processor)
+
         if isinstance(prompt, str):
             prompt = [prompt]
         if isinstance(negative_prompt, str):
...@@ -1305,9 +1306,9 @@ class TextToVideoZeroSDXLPipeline(
         image = self.image_processor.postprocess(image, output_type=output_type)

-        # Offload last model to CPU manually for max memory savings
-        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-            self.final_offload_hook.offload()
+        self.maybe_free_model_hooks()
+        # make sure to set the original attention processors back
+        self.unet.set_attn_processor(original_attn_proc)

         if not return_dict:
             return (image,)
...
...@@ -18,7 +18,12 @@ from diffusers.utils import is_xformers_available, logging
 from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin, SDFunctionTesterMixin
+from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
+    PipelineFromPipeTesterMixin,
+    PipelineTesterMixin,
+    SDFunctionTesterMixin,
+)


 def to_np(tensor):
...@@ -29,7 +34,7 @@ def to_np(tensor):
 class AnimateDiffPipelineFastTests(
-    IPAdapterTesterMixin, SDFunctionTesterMixin, PipelineTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin, SDFunctionTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
 ):
     pipeline_class = AnimateDiffPipeline
     params = TEXT_TO_IMAGE_PARAMS
...
...@@ -18,7 +18,7 @@ from diffusers.utils import is_xformers_available, logging
 from diffusers.utils.testing_utils import torch_device

 from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS
-from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin


 def to_np(tensor):
...@@ -28,7 +28,9 @@ def to_np(tensor):
     return tensor


-class AnimateDiffVideoToVideoPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class AnimateDiffVideoToVideoPipelineFastTests(
+    IPAdapterTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
+):
     pipeline_class = AnimateDiffVideoToVideoPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = VIDEO_TO_VIDEO_BATCH_PARAMS
...
...@@ -17,7 +17,7 @@ from diffusers import (
 from diffusers.utils import is_xformers_available, logging
 from diffusers.utils.testing_utils import floats_tensor, torch_device

-from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineFromPipeTesterMixin, PipelineTesterMixin


 def to_np(tensor):
...@@ -27,7 +27,7 @@ def to_np(tensor):
     return tensor


-class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase):
     pipeline_class = PIAPipeline
     params = frozenset(
         [
...
...@@ -35,7 +35,12 @@ from diffusers.utils.testing_utils import (
 )

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+    PipelineFromPipeTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+)


 torch.backends.cuda.matmul.allow_tf32 = False
...@@ -43,7 +48,11 @@ torch.backends.cuda.matmul.allow_tf32 = False
 @skip_mps
 class StableDiffusionAttendAndExcitePipelineFastTests(
-    PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+    PipelineLatentTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineTesterMixin,
+    PipelineFromPipeTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = StableDiffusionAttendAndExcitePipeline
     test_attention_slicing = False
...
...@@ -43,13 +43,15 @@ from diffusers.utils.testing_utils import (
 )

 from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
-from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import PipelineFromPipeTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin


 enable_full_determinism()


-class StableDiffusionDiffEditPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionDiffEditPipelineFastTests(
+    PipelineLatentTesterMixin, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
+):
     pipeline_class = StableDiffusionDiffEditPipeline
     params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS - {"height", "width", "image"} | {"image_latents"}
     batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS - {"image"} | {"image_latents"}
...
...@@ -46,7 +46,7 @@ from diffusers.utils.testing_utils import (
 )

 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference
+from ..test_pipelines_common import PipelineFromPipeTesterMixin, PipelineTesterMixin, assert_mean_pixel_difference


 enable_full_determinism()
...@@ -337,7 +337,9 @@ class AdapterTests:
     assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2


-class StableDiffusionFullAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionFullAdapterPipelineFastTests(
+    AdapterTests, PipelineTesterMixin, PipelineFromPipeTesterMixin, unittest.TestCase
+):
     def get_dummy_components(self, time_cond_proj_dim=None):
         return super().get_dummy_components("full_adapter", time_cond_proj_dim=time_cond_proj_dim)
...
...@@ -33,14 +33,23 @@ from ..pipeline_params import (
     TEXT_TO_IMAGE_IMAGE_PARAMS,
     TEXT_TO_IMAGE_PARAMS,
 )
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+    PipelineFromPipeTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+)


 enable_full_determinism()


 class GligenPipelineFastTests(
-    PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+    PipelineLatentTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineTesterMixin,
+    PipelineFromPipeTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = StableDiffusionGLIGENPipeline
     params = TEXT_TO_IMAGE_PARAMS | {"gligen_phrases", "gligen_boxes"}
...
...@@ -42,14 +42,23 @@ from ..pipeline_params import (
     TEXT_TO_IMAGE_IMAGE_PARAMS,
     TEXT_TO_IMAGE_PARAMS,
 )
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+    PipelineFromPipeTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+)


 enable_full_determinism()


 class GligenTextImagePipelineFastTests(
-    PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+    PipelineLatentTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineTesterMixin,
+    PipelineFromPipeTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = StableDiffusionGLIGENTextImagePipeline
     params = TEXT_TO_IMAGE_PARAMS | {"gligen_phrases", "gligen_images", "gligen_boxes"}
...
...@@ -32,7 +32,12 @@ from diffusers import (
 from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, skip_mps, torch_device

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
+    PipelineFromPipeTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+)


 enable_full_determinism()
...@@ -40,7 +45,11 @@ enable_full_determinism()
 @skip_mps
 class StableDiffusionPanoramaPipelineFastTests(
-    IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+    PipelineFromPipeTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = StableDiffusionPanoramaPipeline
     params = TEXT_TO_IMAGE_PARAMS
...
...@@ -32,14 +32,23 @@ from diffusers import (
 from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device

 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
+    PipelineFromPipeTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+)


 enable_full_determinism()


 class StableDiffusionSAGPipelineFastTests(
-    IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+    PipelineFromPipeTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = StableDiffusionSAGPipeline
     params = TEXT_TO_IMAGE_PARAMS
...