"torchvision/git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "c67a58396d894c48897b96fce4eae0bcb3fd2755"
Unverified Commit a7bf77fc authored by Aleksei Zhuravlev's avatar Aleksei Zhuravlev Committed by GitHub
Browse files

Passing `cross_attention_kwargs` to `StableDiffusionInstructPix2PixPipeline` (#7961)

* Update pipeline_stable_diffusion_instruct_pix2pix.py

Add `cross_attention_kwargs` to `__call__` method of `StableDiffusionInstructPix2PixPipeline`, which are passed to UNet.

* Update documentation for pipeline_stable_diffusion_instruct_pix2pix.py

* Update docstring

* Update docstring

* Fix typing import
parent 0f0defdb
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import inspect import inspect
from typing import Callable, Dict, List, Optional, Union from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np import numpy as np
import PIL.Image import PIL.Image
...@@ -180,6 +180,7 @@ class StableDiffusionInstructPix2PixPipeline( ...@@ -180,6 +180,7 @@ class StableDiffusionInstructPix2PixPipeline(
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks] Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None, ] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"], callback_on_step_end_tensor_inputs: List[str] = ["latents"],
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
**kwargs, **kwargs,
): ):
r""" r"""
...@@ -239,6 +240,9 @@ class StableDiffusionInstructPix2PixPipeline( ...@@ -239,6 +240,9 @@ class StableDiffusionInstructPix2PixPipeline(
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class. `._callback_tensor_inputs` attribute of your pipeline class.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
Examples: Examples:
...@@ -415,6 +419,7 @@ class StableDiffusionInstructPix2PixPipeline( ...@@ -415,6 +419,7 @@ class StableDiffusionInstructPix2PixPipeline(
t, t,
encoder_hidden_states=prompt_embeds, encoder_hidden_states=prompt_embeds,
added_cond_kwargs=added_cond_kwargs, added_cond_kwargs=added_cond_kwargs,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False, return_dict=False,
)[0] )[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment