# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import html
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import regex as re
import torch
from torch.nn import functional as F
from transformers import CLIPTextModel, CLIPTokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2VLProcessor

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...loaders import KandinskyLoraLoaderMixin
from ...models import AutoencoderKL
from ...models.transformers import Kandinsky5Transformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
    is_ftfy_available,
    is_torch_xla_available,
    logging,
    replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import KandinskyImagePipelineOutput


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

if is_ftfy_available():
    import ftfy


EXAMPLE_DOC_STRING = """
    Examples:
        ```python
        >>> import torch
        >>> from diffusers import Kandinsky5T2IPipeline

        >>> # Available models:
        >>> # kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers
        >>> # kandinskylab/Kandinsky-5.0-T2I-Lite-pretrain-Diffusers
        >>> model_id = "kandinskylab/Kandinsky-5.0-T2I-Lite-sft-Diffusers"

        >>> pipe = Kandinsky5T2IPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
        >>> pipe = pipe.to("cuda")

        >>> prompt = "A cat and a dog baking a cake together in a kitchen."

        >>> image = pipe(
        ...     prompt=prompt,
        ...     negative_prompt="",
        ...     height=1024,
        ...     width=1024,
        ...     num_inference_steps=50,
        ...     guidance_scale=3.5,
        ... ).image[0]
        ```
"""


def basic_clean(text):
    """
    Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py

    Clean text using ftfy if available and unescape HTML entities.
    """
    if is_ftfy_available():
        text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    """
    Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py

    Normalize whitespace in text by replacing runs of whitespace with a single space.
    """
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def prompt_clean(text):
    """
    Copied from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/wan/pipeline_wan.py

    Apply both basic cleaning and whitespace normalization to prompts.
    """
    text = whitespace_clean(basic_clean(text))
    return text


class Kandinsky5T2IPipeline(DiffusionPipeline, KandinskyLoraLoaderMixin):
    r"""
    Pipeline for text-to-image generation using Kandinsky 5.0.

    This model inherits from [`DiffusionPipeline`].
    Check the superclass documentation for the generic methods implemented for all pipelines (downloading, saving,
    running on a particular device, etc.).

    Args:
        transformer ([`Kandinsky5Transformer3DModel`]):
            Conditional Transformer to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder Model
            [black-forest-labs/FLUX.1-dev (vae)](https://huggingface.co/black-forest-labs/FLUX.1-dev) to encode and
            decode images to and from latent representations.
        text_encoder ([`Qwen2_5_VLForConditionalGeneration`]):
            Frozen text-encoder [Qwen2.5-VL](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct).
        tokenizer ([`Qwen2VLProcessor`]):
            Tokenizer for Qwen2.5-VL.
        text_encoder_2 ([`CLIPTextModel`]):
            Frozen [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel),
            specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer_2 ([`CLIPTokenizer`]):
            Tokenizer for CLIP.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    """

    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
    _callback_tensor_inputs = [
        "latents",
        "prompt_embeds_qwen",
        "prompt_embeds_clip",
        "negative_prompt_embeds_qwen",
        "negative_prompt_embeds_clip",
    ]

    def __init__(
        self,
        transformer: Kandinsky5Transformer3DModel,
        vae: AutoencoderKL,
        text_encoder: Qwen2_5_VLForConditionalGeneration,
        tokenizer: Qwen2VLProcessor,
        text_encoder_2: CLIPTextModel,
        tokenizer_2: CLIPTokenizer,
        scheduler: FlowMatchEulerDiscreteScheduler,
    ):
        super().__init__()

        self.register_modules(
            transformer=transformer,
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            text_encoder_2=text_encoder_2,
            tokenizer_2=tokenizer_2,
            scheduler=scheduler,
        )

        # Keep this template text (including the "promt" spelling) unchanged: the hard-coded start index
        # below counts the template tokens that precede the user prompt.
        self.prompt_template = "<|im_start|>system\nYou are a promt engineer. Describe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>"
        self.prompt_template_encode_start_idx = 41
        self.vae_scale_factor_spatial = 8
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
        self.resolutions = [(1024, 1024), (640, 1408), (1408, 640), (768, 1280), (1280, 768), (896, 1152), (1152, 896)]

    def _encode_prompt_qwen(
        self,
        prompt: List[str],
        device: Optional[torch.device] = None,
        max_sequence_length: int = 512,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        Encode prompt using Qwen2.5-VL text encoder.

        This method processes the input prompt through the Qwen2.5-VL model to generate text embeddings suitable for
        image generation.
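
        After the template tokens are dropped, the embeddings are returned together with cumulative sequence lengths
        built from the attention mask. A minimal toy illustration of that bookkeeping (illustrative values only, not
        real tokenizer output):

        ```python
        import torch
        import torch.nn.functional as F

        # Two prompts in a batch, with true lengths 3 and 5 after masking.
        attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
        cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0)  # tensor([3, 8])
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32)  # tensor([0, 3, 8])
        # Downstream, per-prompt lengths are recovered with `cu_seqlens.diff()`, e.g. to size the text RoPE positions.
        ```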
Args: prompt List[str]: Input list of prompts device (torch.device): Device to run encoding on max_sequence_length (int): Maximum sequence length for tokenization dtype (torch.dtype): Data type for embeddings Returns: Tuple[torch.Tensor, torch.Tensor]: Text embeddings and cumulative sequence lengths """ device = device or self._execution_device dtype = dtype or self.text_encoder.dtype full_texts = [self.prompt_template.format(p) for p in prompt] max_allowed_len = self.prompt_template_encode_start_idx + max_sequence_length untruncated_ids = self.tokenizer( text=full_texts, images=None, videos=None, return_tensors="pt", padding="longest", )["input_ids"] if untruncated_ids.shape[-1] > max_allowed_len: for i, text in enumerate(full_texts): tokens = untruncated_ids[i][self.prompt_template_encode_start_idx : -2] removed_text = self.tokenizer.decode(tokens[max_sequence_length - 2 :]) if len(removed_text) > 0: full_texts[i] = text[: -len(removed_text)] logger.warning( "The following part of your input was truncated because `max_sequence_length` is set to " f" {max_sequence_length} tokens: {removed_text}" ) inputs = self.tokenizer( text=full_texts, images=None, videos=None, max_length=max_allowed_len, truncation=True, return_tensors="pt", padding=True, ).to(device) embeds = self.text_encoder( input_ids=inputs["input_ids"], return_dict=True, output_hidden_states=True, )["hidden_states"][-1][:, self.prompt_template_encode_start_idx :] attention_mask = inputs["attention_mask"][:, self.prompt_template_encode_start_idx :] cu_seqlens = torch.cumsum(attention_mask.sum(1), dim=0) cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).to(dtype=torch.int32) return embeds.to(dtype), cu_seqlens def _encode_prompt_clip( self, prompt: Union[str, List[str]], device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): """ Encode prompt using CLIP text encoder. This method processes the input prompt through the CLIP model to generate pooled embeddings that capture semantic information. Args: prompt (Union[str, List[str]]): Input prompt or list of prompts device (torch.device): Device to run encoding on dtype (torch.dtype): Data type for embeddings Returns: torch.Tensor: Pooled text embeddings from CLIP """ device = device or self._execution_device dtype = dtype or self.text_encoder_2.dtype inputs = self.tokenizer_2( prompt, max_length=77, truncation=True, add_special_tokens=True, padding="max_length", return_tensors="pt", ).to(device) pooled_embed = self.text_encoder_2(**inputs)["pooler_output"] return pooled_embed.to(dtype) def encode_prompt( self, prompt: Union[str, List[str]], num_images_per_prompt: int = 1, max_sequence_length: int = 512, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, ): r""" Encodes a single prompt (positive or negative) into text encoder hidden states. This method combines embeddings from both Qwen2.5-VL and CLIP text encoders to create comprehensive text representations for image generation. Args: prompt (`str` or `List[str]`): Prompt to be encoded. num_images_per_prompt (`int`, *optional*, defaults to 1): Number of images to generate per prompt. max_sequence_length (`int`, *optional*, defaults to 512): Maximum sequence length for text encoding. Must be less than 1024 device (`torch.device`, *optional*): Torch device. dtype (`torch.dtype`, *optional*): Torch dtype. 
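
        Example:
            A hedged sketch of precomputing embeddings once and reusing them in `__call__` (assumes `pipe` is an
            already-loaded `Kandinsky5T2IPipeline` on the target device):

            ```python
            prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = pipe.encode_prompt(
                prompt="A cat and a dog baking a cake together in a kitchen.",
                num_images_per_prompt=1,
            )
            image = pipe(
                prompt_embeds_qwen=prompt_embeds_qwen,
                prompt_embeds_clip=prompt_embeds_clip,
                prompt_cu_seqlens=prompt_cu_seqlens,
                negative_prompt="",
                num_inference_steps=50,
            ).image[0]
            ```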
Returns: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - Qwen text embeddings of shape (batch_size * num_images_per_prompt, sequence_length, embedding_dim) - CLIP pooled embeddings of shape (batch_size * num_images_per_prompt, clip_embedding_dim) - Cumulative sequence lengths (`cu_seqlens`) for Qwen embeddings of shape (batch_size * num_images_per_prompt + 1,) """ device = device or self._execution_device dtype = dtype or self.text_encoder.dtype if not isinstance(prompt, list): prompt = [prompt] batch_size = len(prompt) prompt = [prompt_clean(p) for p in prompt] # Encode with Qwen2.5-VL prompt_embeds_qwen, prompt_cu_seqlens = self._encode_prompt_qwen( prompt=prompt, device=device, max_sequence_length=max_sequence_length, dtype=dtype, ) # prompt_embeds_qwen shape: [batch_size, seq_len, embed_dim] # Encode with CLIP prompt_embeds_clip = self._encode_prompt_clip( prompt=prompt, device=device, dtype=dtype, ) # prompt_embeds_clip shape: [batch_size, clip_embed_dim] # Repeat embeddings for num_images_per_prompt # Qwen embeddings: repeat sequence for each image, then reshape prompt_embeds_qwen = prompt_embeds_qwen.repeat( 1, num_images_per_prompt, 1 ) # [batch_size, seq_len * num_images_per_prompt, embed_dim] # Reshape to [batch_size * num_images_per_prompt, seq_len, embed_dim] prompt_embeds_qwen = prompt_embeds_qwen.view( batch_size * num_images_per_prompt, -1, prompt_embeds_qwen.shape[-1] ) # CLIP embeddings: repeat for each image prompt_embeds_clip = prompt_embeds_clip.repeat( 1, num_images_per_prompt, 1 ) # [batch_size, num_images_per_prompt, clip_embed_dim] # Reshape to [batch_size * num_images_per_prompt, clip_embed_dim] prompt_embeds_clip = prompt_embeds_clip.view(batch_size * num_images_per_prompt, -1) # Repeat cumulative sequence lengths for num_images_per_prompt # Original differences (lengths) for each prompt in the batch original_lengths = prompt_cu_seqlens.diff() # [len1, len2, ...] # Repeat the lengths for num_images_per_prompt repeated_lengths = original_lengths.repeat_interleave( num_images_per_prompt ) # [len1, len1, ..., len2, len2, ...] # Reconstruct the cumulative lengths repeated_cu_seqlens = torch.cat( [torch.tensor([0], device=device, dtype=torch.int32), repeated_lengths.cumsum(0)] ) return prompt_embeds_qwen, prompt_embeds_clip, repeated_cu_seqlens def check_inputs( self, prompt, negative_prompt, height, width, prompt_embeds_qwen=None, prompt_embeds_clip=None, negative_prompt_embeds_qwen=None, negative_prompt_embeds_clip=None, prompt_cu_seqlens=None, negative_prompt_cu_seqlens=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): """ Validate input parameters for the pipeline. 
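
        Unsupported `(width, height)` pairs only trigger a warning here; `__call__` later snaps them to the supported
        resolution with the closest aspect ratio, as in this toy illustration (values mirror `self.resolutions`):

        ```python
        import numpy as np

        resolutions = [(1024, 1024), (640, 1408), (1408, 640), (768, 1280), (1280, 768), (896, 1152), (1152, 896)]
        width, height = 800, 1200  # requested size, not in the list
        width, height = resolutions[np.argmin([abs((w / h) - (800 / 1200)) for w, h in resolutions])]
        print(width, height)  # 768 1280 -> the closest supported aspect ratio
        ```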
Args: prompt: Input prompt negative_prompt: Negative prompt for guidance height: Image height width: Image width prompt_embeds_qwen: Pre-computed Qwen prompt embeddings prompt_embeds_clip: Pre-computed CLIP prompt embeddings negative_prompt_embeds_qwen: Pre-computed Qwen negative prompt embeddings negative_prompt_embeds_clip: Pre-computed CLIP negative prompt embeddings prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen positive prompt negative_prompt_cu_seqlens: Pre-computed cumulative sequence lengths for Qwen negative prompt callback_on_step_end_tensor_inputs: Callback tensor inputs Raises: ValueError: If inputs are invalid """ if max_sequence_length is not None and max_sequence_length > 1024: raise ValueError("max_sequence_length must be less than 1024") if (width, height) not in self.resolutions: resolutions_str = ",".join([f"({w},{h})" for w, h in self.resolutions]) logger.warning( f"`height` and `width` have to be one of {resolutions_str}, but are {height} and {width}. Dimensions will be resized accordingly" ) if callback_on_step_end_tensor_inputs is not None and not all( k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs ): raise ValueError( f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" ) # Check for consistency within positive prompt embeddings and sequence lengths if prompt_embeds_qwen is not None or prompt_embeds_clip is not None or prompt_cu_seqlens is not None: if prompt_embeds_qwen is None or prompt_embeds_clip is None or prompt_cu_seqlens is None: raise ValueError( "If any of `prompt_embeds_qwen`, `prompt_embeds_clip`, or `prompt_cu_seqlens` is provided, " "all three must be provided." ) # Check for consistency within negative prompt embeddings and sequence lengths if ( negative_prompt_embeds_qwen is not None or negative_prompt_embeds_clip is not None or negative_prompt_cu_seqlens is not None ): if ( negative_prompt_embeds_qwen is None or negative_prompt_embeds_clip is None or negative_prompt_cu_seqlens is None ): raise ValueError( "If any of `negative_prompt_embeds_qwen`, `negative_prompt_embeds_clip`, or `negative_prompt_cu_seqlens` is provided, " "all three must be provided." ) # Check if prompt or embeddings are provided (either prompt or all required embedding components for positive) if prompt is None and prompt_embeds_qwen is None: raise ValueError( "Provide either `prompt` or `prompt_embeds_qwen` (and corresponding `prompt_embeds_clip` and `prompt_cu_seqlens`). Cannot leave all undefined." ) # Validate types for prompt and negative_prompt if provided if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") if negative_prompt is not None and ( not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list) ): raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") def prepare_latents( self, batch_size: int, num_channels_latents: int = 16, height: int = 1024, width: int = 1024, dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Prepare initial latent variables for text-to-image generation. 
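
        A shape sketch of the packed latent layout produced here (toy values; `num_channels_latents` defaults to 16
        and the spatial scale factor is 8):

        ```python
        import torch

        batch_size, num_channels_latents, height, width = 1, 16, 1024, 1024
        shape = (batch_size, 1, height // 8, width // 8, num_channels_latents)
        latents = torch.randn(shape)
        print(latents.shape)  # torch.Size([1, 1, 128, 128, 16])
        ```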
        This method creates random noise latents in the packed, channels-last layout
        `(batch, 1, height // vae_scale_factor_spatial, width // vae_scale_factor_spatial, channels)` expected by the
        transformer, or reuses pre-existing `latents` if they are provided.

        Args:
            batch_size (int): Number of images to generate
            num_channels_latents (int): Number of channels in latent space
            height (int): Height of generated image
            width (int): Width of generated image
            dtype (torch.dtype): Data type for latents
            device (torch.device): Device to create latents on
            generator (torch.Generator): Random number generator
            latents (torch.Tensor): Pre-existing latents to use

        Returns:
            torch.Tensor: Prepared latent tensor
        """
        if latents is not None:
            return latents.to(device=device, dtype=dtype)

        shape = (
            batch_size,
            1,
            int(height) // self.vae_scale_factor_spatial,
            int(width) // self.vae_scale_factor_spatial,
            num_channels_latents,
        )

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        # Generate random noise
        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)

        return latents

    @property
    def guidance_scale(self):
        """Get the current guidance scale value."""
        return self._guidance_scale

    @property
    def num_timesteps(self):
        """Get the number of denoising timesteps."""
        return self._num_timesteps

    @property
    def interrupt(self):
        """Check if generation has been interrupted."""
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 1024,
        width: int = 1024,
        num_inference_steps: int = 50,
        guidance_scale: float = 3.5,
        num_images_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds_qwen: Optional[torch.Tensor] = None,
        prompt_embeds_clip: Optional[torch.Tensor] = None,
        negative_prompt_embeds_qwen: Optional[torch.Tensor] = None,
        negative_prompt_embeds_clip: Optional[torch.Tensor] = None,
        prompt_cu_seqlens: Optional[torch.Tensor] = None,
        negative_prompt_cu_seqlens: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
        r"""
        The call function to the pipeline for text-to-image generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds_qwen`,
                `prompt_embeds_clip`, and `prompt_cu_seqlens` instead.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to avoid during image generation. If not defined, pass
                `negative_prompt_embeds_qwen` and the related tensors instead. Ignored when not using guidance
                (`guidance_scale` <= `1`).
            height (`int`, defaults to `1024`):
                The height in pixels of the generated image.
            width (`int`, defaults to `1024`):
                The width in pixels of the generated image.
            num_inference_steps (`int`, defaults to `50`):
                The number of denoising steps. More steps usually give higher quality at the cost of slower inference.
            guidance_scale (`float`, defaults to `3.5`):
                Guidance scale as defined in classifier-free guidance. Higher values encourage closer adherence to the
                prompt; values `<= 1` disable guidance.
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A torch generator to make generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents.
prompt_embeds_qwen (`torch.Tensor`, *optional*): Pre-generated Qwen text embeddings. prompt_embeds_clip (`torch.Tensor`, *optional*): Pre-generated CLIP text embeddings. negative_prompt_embeds_qwen (`torch.Tensor`, *optional*): Pre-generated Qwen negative text embeddings. negative_prompt_embeds_clip (`torch.Tensor`, *optional*): Pre-generated CLIP negative text embeddings. prompt_cu_seqlens (`torch.Tensor`, *optional*): Pre-generated cumulative sequence lengths for Qwen positive prompt. negative_prompt_cu_seqlens (`torch.Tensor`, *optional*): Pre-generated cumulative sequence lengths for Qwen negative prompt. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated image. return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`KandinskyImagePipelineOutput`]. callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): A function that is called at the end of each denoising step. callback_on_step_end_tensor_inputs (`List`, *optional*): The list of tensor inputs for the `callback_on_step_end` function. max_sequence_length (`int`, defaults to `512`): The maximum sequence length for text encoding. Examples: Returns: [`~KandinskyImagePipelineOutput`] or `tuple`: If `return_dict` is `True`, [`KandinskyImagePipelineOutput`] is returned, otherwise a `tuple` is returned where the first element is a list with the generated images. """ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)): callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs self.check_inputs( prompt=prompt, negative_prompt=negative_prompt, height=height, width=width, prompt_embeds_qwen=prompt_embeds_qwen, prompt_embeds_clip=prompt_embeds_clip, negative_prompt_embeds_qwen=negative_prompt_embeds_qwen, negative_prompt_embeds_clip=negative_prompt_embeds_clip, prompt_cu_seqlens=prompt_cu_seqlens, negative_prompt_cu_seqlens=negative_prompt_cu_seqlens, callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) if (width, height) not in self.resolutions: width, height = self.resolutions[ np.argmin([abs((i[0] / i[1]) - (width / height)) for i in self.resolutions]) ] self._guidance_scale = guidance_scale self._interrupt = False device = self._execution_device dtype = self.transformer.dtype # 2. Define call parameters if prompt is not None and isinstance(prompt, str): batch_size = 1 prompt = [prompt] elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) else: batch_size = prompt_embeds_qwen.shape[0] # 3. Encode input prompt if prompt_embeds_qwen is None: prompt_embeds_qwen, prompt_embeds_clip, prompt_cu_seqlens = self.encode_prompt( prompt=prompt, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype, ) if self.guidance_scale > 1.0: if negative_prompt is None: negative_prompt = "" if isinstance(negative_prompt, str): negative_prompt = [negative_prompt] * len(prompt) if prompt is not None else [negative_prompt] elif len(negative_prompt) != len(prompt): raise ValueError( f"`negative_prompt` must have same length as `prompt`. Got {len(negative_prompt)} vs {len(prompt)}." ) if negative_prompt_embeds_qwen is None: negative_prompt_embeds_qwen, negative_prompt_embeds_clip, negative_prompt_cu_seqlens = ( self.encode_prompt( prompt=negative_prompt, num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype, ) ) # 4. 
Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps = self.scheduler.timesteps # 5. Prepare latent variables num_channels_latents = self.transformer.config.in_visual_dim latents = self.prepare_latents( batch_size=batch_size * num_images_per_prompt, num_channels_latents=num_channels_latents, height=height, width=width, dtype=dtype, device=device, generator=generator, latents=latents, ) # 6. Prepare rope positions for positional encoding visual_rope_pos = [ torch.arange(1, device=device), torch.arange(height // self.vae_scale_factor_spatial // 2, device=device), torch.arange(width // self.vae_scale_factor_spatial // 2, device=device), ] text_rope_pos = torch.arange(prompt_cu_seqlens.diff().max().item(), device=device) negative_text_rope_pos = ( torch.arange(negative_prompt_cu_seqlens.diff().max().item(), device=device) if negative_prompt_cu_seqlens is not None else None ) # 7. Calculate dynamic scale factor based on resolution scale_factor = [1.0, 1.0, 1.0] # 8. Sparse Params for efficient attention sparse_params = None # 9. Denoising loop num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order self._num_timesteps = len(timesteps) with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): if self.interrupt: continue timestep = t.unsqueeze(0).repeat(batch_size * num_images_per_prompt) # Predict noise residual pred_velocity = self.transformer( hidden_states=latents.to(dtype), encoder_hidden_states=prompt_embeds_qwen.to(dtype), pooled_projections=prompt_embeds_clip.to(dtype), timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, text_rope_pos=text_rope_pos, scale_factor=scale_factor, sparse_params=sparse_params, return_dict=True, ).sample if self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None: uncond_pred_velocity = self.transformer( hidden_states=latents.to(dtype), encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), pooled_projections=negative_prompt_embeds_clip.to(dtype), timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, text_rope_pos=negative_text_rope_pos, scale_factor=scale_factor, sparse_params=sparse_params, return_dict=True, ).sample pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity) latents = self.scheduler.step(pred_velocity[:, :], t, latents, return_dict=False)[0] if callback_on_step_end is not None: callback_kwargs = {} for k in callback_on_step_end_tensor_inputs: callback_kwargs[k] = locals()[k] callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) latents = callback_outputs.pop("latents", latents) prompt_embeds_qwen = callback_outputs.pop("prompt_embeds_qwen", prompt_embeds_qwen) prompt_embeds_clip = callback_outputs.pop("prompt_embeds_clip", prompt_embeds_clip) negative_prompt_embeds_qwen = callback_outputs.pop( "negative_prompt_embeds_qwen", negative_prompt_embeds_qwen ) negative_prompt_embeds_clip = callback_outputs.pop( "negative_prompt_embeds_clip", negative_prompt_embeds_clip ) if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if XLA_AVAILABLE: xm.mark_step() # 9. Post-processing - extract main latents latents = latents[:, :, :, :, :num_channels_latents] # 10. 
Decode latents to image if output_type != "latent": latents = latents.to(self.vae.dtype) # Reshape and normalize latents latents = latents.reshape( batch_size, num_images_per_prompt, 1, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial, num_channels_latents, ) latents = latents.permute(0, 1, 5, 2, 3, 4) # [batch, num_images, channels, 1, height, width] latents = latents.reshape( batch_size * num_images_per_prompt, num_channels_latents, height // self.vae_scale_factor_spatial, width // self.vae_scale_factor_spatial, ) # Normalize and decode through VAE latents = latents / self.vae.config.scaling_factor image = self.vae.decode(latents).sample image = self.image_processor.postprocess(image, output_type=output_type) else: image = latents # Offload all models self.maybe_free_model_hooks() if not return_dict: return (image,) return KandinskyImagePipelineOutput(image=image)