Unverified Commit 3e99b567 authored by Parag Ekbote, committed by GitHub

Extend Support for callback_on_step_end for AuraFlow and LuminaText2Img Pipelines (#10746)



* Add support for callback_on_step_end for AuraFlowPipeline and LuminaText2ImgPipeline.

* Apply the suggestions from code review for lumina and auraflow
Co-authored-by: hlky <hlky@hlky.ac>

* Update missing inputs and imports.

* Add input field.

* Apply suggestions from code review-2
Co-authored-by: hlky <hlky@hlky.ac>

* Apply the suggestions from review for unused imports.
Co-authored-by: hlky <hlky@hlky.ac>

* make style.

* Update pipeline_aura_flow.py

* Update pipeline_lumina.py

* Update pipeline_lumina.py

* Update pipeline_aura_flow.py

* Update pipeline_lumina.py

---------
Co-authored-by: hlky <hlky@hlky.ac>
parent 952b9131
pipeline_aura_flow.py
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-from typing import List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import T5Tokenizer, UMT5EncoderModel

+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import VaeImageProcessor
 from ...models import AuraFlowTransformer2DModel, AutoencoderKL
 from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
@@ -131,6 +132,10 @@ class AuraFlowPipeline(DiffusionPipeline):
     _optional_components = []
     model_cpu_offload_seq = "text_encoder->transformer->vae"
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+    ]

     def __init__(
         self,
@@ -159,12 +164,19 @@ class AuraFlowPipeline(DiffusionPipeline):
         negative_prompt_embeds=None,
         prompt_attention_mask=None,
         negative_prompt_attention_mask=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             raise ValueError(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
             )

+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -387,6 +399,14 @@ class AuraFlowPipeline(DiffusionPipeline):
         self.vae.decoder.conv_in.to(dtype)
         self.vae.decoder.mid_block.to(dtype)

+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def num_timesteps(self):
+        return self._num_timesteps
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -408,6 +428,10 @@ class AuraFlowPipeline(DiffusionPipeline):
         max_sequence_length: int = 256,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
         Function invoked when calling the pipeline for generation.
@@ -462,6 +486,15 @@ class AuraFlowPipeline(DiffusionPipeline):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                 of a plain tuple.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+                the `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.

         Examples:
@@ -483,8 +516,11 @@ class AuraFlowPipeline(DiffusionPipeline):
             negative_prompt_embeds,
             prompt_attention_mask,
             negative_prompt_attention_mask,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )

+        self._guidance_scale = guidance_scale
+
         # 2. Determine batch size.
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -541,6 +577,7 @@ class AuraFlowPipeline(DiffusionPipeline):
         # 6. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+        self._num_timesteps = len(timesteps)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
@@ -567,6 +604,15 @@ class AuraFlowPipeline(DiffusionPipeline):
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
...
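With this change, `AuraFlowPipeline` runs a user-supplied callback at the end of every denoising step and feeds any returned `latents`/`prompt_embeds` back into the loop. A minimal sketch of how the new argument could be used; the `fal/AuraFlow` checkpoint id, prompt, and logging logic are illustrative, not part of this commit:

```python
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")

def log_latents(pipeline, step, timestep, callback_kwargs):
    # "latents" is present because it is requested via
    # callback_on_step_end_tensor_inputs below.
    latents = callback_kwargs["latents"]
    print(f"step {step} (t={timestep}): latent norm = {latents.norm().item():.3f}")
    # Anything returned here is popped back into the denoising loop.
    return callback_kwargs

image = pipe(
    "a photo of an astronaut riding a horse on mars",
    num_inference_steps=25,
    callback_on_step_end=log_latents,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]
```

Passing a tensor name that is not registered in `_callback_tensor_inputs` (only `latents` and `prompt_embeds` are) now raises a `ValueError` from `check_inputs`.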
pipeline_lumina.py
@@ -17,11 +17,12 @@ import inspect
 import math
 import re
 import urllib.parse as ul
-from typing import List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import AutoModel, AutoTokenizer

+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL
 from ...models.embeddings import get_2d_rotary_pos_embed_lumina
@@ -174,6 +175,10 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
     _optional_components = []
     model_cpu_offload_seq = "text_encoder->transformer->vae"
+    _callback_tensor_inputs = [
+        "latents",
+        "prompt_embeds",
+    ]

     def __init__(
         self,
@@ -395,12 +400,20 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
         negative_prompt_embeds=None,
         prompt_attention_mask=None,
         negative_prompt_attention_mask=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
             raise ValueError(
                 f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}."
             )

+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -644,6 +657,10 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
         max_sequence_length: int = 256,
         scaling_watershed: Optional[float] = 1.0,
         proportional_attn: Optional[bool] = True,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
     ) -> Union[ImagePipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
@@ -735,7 +752,11 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
             negative_prompt_embeds=negative_prompt_embeds,
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
         )

+        self._guidance_scale = guidance_scale
+
         cross_attention_kwargs = {}

         # 2. Define call parameters
@@ -797,6 +818,8 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
             latents,
         )

+        self._num_timesteps = len(timesteps)
+
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
@@ -886,6 +909,15 @@ class LuminaText2ImgPipeline(DiffusionPipeline):
                 progress_bar.update()

+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+
                 if XLA_AVAILABLE:
                     xm.mark_step()
...
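`LuminaText2ImgPipeline` follows the same pattern, and since `prompt_embeds` is also registered in `_callback_tensor_inputs`, a callback can rewrite the conditioning mid-run. A sketch under the assumption of the `Alpha-VLLM/Lumina-Next-SFT-diffusers` checkpoint; the halfway-point dampening is purely illustrative:

```python
import torch
from diffusers import LuminaText2ImgPipeline

pipe = LuminaText2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Next-SFT-diffusers", torch_dtype=torch.bfloat16
).to("cuda")

def dampen_prompt(pipeline, step, timestep, callback_kwargs):
    # Both names are valid here because both appear in _callback_tensor_inputs.
    if step == 15:  # halfway through the 30 steps requested below
        callback_kwargs["prompt_embeds"] = callback_kwargs["prompt_embeds"] * 0.9
    return callback_kwargs

image = pipe(
    "Upper body of a young woman in a Victorian-era outfit",
    num_inference_steps=30,
    callback_on_step_end=dampen_prompt,
    callback_on_step_end_tensor_inputs=["latents", "prompt_embeds"],
).images[0]
```

Note that in this pipeline the callback fires after `progress_bar.update()` and before the XLA `mark_step()`, so modified tensors take effect on the next iteration.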