Unverified commit 9a38fab5 authored by Aryan, committed by GitHub

tests + minor refactor for QwenImage (#12057)

* update

* update

* update

* add docs
parent cb8e61ed
@@ -366,6 +366,8 @@
       title: PixArtTransformer2DModel
     - local: api/models/prior_transformer
       title: PriorTransformer
+    - local: api/models/qwenimage_transformer2d
+      title: QwenImageTransformer2DModel
     - local: api/models/sana_transformer2d
       title: SanaTransformer2DModel
     - local: api/models/sd3_transformer2d
@@ -418,6 +420,8 @@
       title: AutoencoderKLMagvit
     - local: api/models/autoencoderkl_mochi
       title: AutoencoderKLMochi
+    - local: api/models/autoencoderkl_qwenimage
+      title: AutoencoderKLQwenImage
     - local: api/models/autoencoder_kl_wan
       title: AutoencoderKLWan
     - local: api/models/consistency_decoder_vae
@@ -554,6 +558,8 @@
       title: PixArt-α
     - local: api/pipelines/pixart_sigma
       title: PixArt-Σ
+    - local: api/pipelines/qwenimage
+      title: QwenImage
     - local: api/pipelines/sana
       title: Sana
     - local: api/pipelines/sana_sprint
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLQwenImage
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLQwenImage
vae = AutoencoderKLQwenImage.from_pretrained("Qwen/QwenImage-20B", subfolder="vae")
```
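For quick, offline experimentation, the tiny configuration used by this PR's fast tests can be instantiated directly. The sketch below is illustrative only: the 5D `(batch, channels, frames, height, width)` input layout is an assumption based on the Wan-style interface this VAE follows, and the toy sizes are not the released configuration.

```python
import torch

from diffusers import AutoencoderKLQwenImage

# Toy configuration borrowed from the fast tests in this PR (not the released weights).
vae = AutoencoderKLQwenImage(
    base_dim=24,
    z_dim=4,
    dim_mult=[1, 2, 4],
    num_res_blocks=1,
    temperal_downsample=[False, True],
    latents_mean=[0.0] * 4,
    latents_std=[1.0] * 4,
)

# Assumed Wan-style layout: (batch, channels, frames, height, width); frames=1 for a single image.
pixels = torch.randn(1, 3, 1, 32, 32)
latents = vae.encode(pixels).latent_dist.sample()  # AutoencoderKLOutput -> sampled latents
reconstruction = vae.decode(latents).sample        # DecoderOutput -> reconstructed pixels
print(latents.shape, reconstruction.shape)
```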
## AutoencoderKLQwenImage
[[autodoc]] AutoencoderKLQwenImage
- decode
- encode
- all
## AutoencoderKLOutput
[[autodoc]] models.autoencoders.autoencoder_kl.AutoencoderKLOutput
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# QwenImageTransformer2DModel
The model can be loaded with the following code snippet.
```python
import torch

from diffusers import QwenImageTransformer2DModel
transformer = QwenImageTransformer2DModel.from_pretrained("Qwen/QwenImage-20B", subfolder="transformer", torch_dtype=torch.bfloat16)
```
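For quick shape or integration checks without downloading the full checkpoint, the model can also be built from a toy configuration; the one below is taken from the fast tests added in this PR and is purely illustrative.

```python
import torch

from diffusers import QwenImageTransformer2DModel

torch.manual_seed(0)
# Toy configuration from this PR's fast tests; not the released model's configuration.
transformer = QwenImageTransformer2DModel(
    patch_size=2,
    in_channels=16,
    out_channels=4,
    num_layers=2,
    attention_head_dim=16,
    num_attention_heads=3,
    joint_attention_dim=16,
    guidance_embeds=False,
    axes_dims_rope=(8, 4, 4),
)
print(f"parameters: {sum(p.numel() for p in transformer.parameters()):,}")
```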
## QwenImageTransformer2DModel
[[autodoc]] QwenImageTransformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# QwenImage
<!-- TODO: update this section when model is out -->
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
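A minimal text-to-image sketch is shown below. It assumes the `Qwen/QwenImage-20B` hub id used in the model docs of this PR and illustrative sampling settings; both may change once the checkpoint is released.

```python
import torch

from diffusers import QwenImagePipeline

# Hub id taken from the model docs in this PR; it may change when the checkpoint is released.
pipe = QwenImagePipeline.from_pretrained("Qwen/QwenImage-20B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="A capybara wearing a suit, holding a sign that reads Hello World",
    negative_prompt="blurry, low quality",
    num_inference_steps=50,  # illustrative, not a tuned default
    true_cfg_scale=4.0,      # values > 1 enable the negative-prompt (true CFG) branch
    generator=torch.Generator("cuda").manual_seed(0),
).images[0]
image.save("qwenimage.png")
```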
## QwenImagePipeline
[[autodoc]] QwenImagePipeline
- all
- __call__
## QwenImagePipelineOutput
[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
@@ -668,6 +668,7 @@ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     _supports_gradient_checkpointing = False

+    # fmt: off
     @register_to_config
     def __init__(
         self,
@@ -678,43 +679,10 @@ class AutoencoderKLQwenImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         attn_scales: List[float] = [],
         temperal_downsample: List[bool] = [False, True, True],
         dropout: float = 0.0,
-        latents_mean: List[float] = [
-            -0.7571,
-            -0.7089,
-            -0.9113,
-            0.1075,
-            -0.1745,
-            0.9653,
-            -0.1517,
-            1.5508,
-            0.4134,
-            -0.0715,
-            0.5517,
-            -0.3632,
-            -0.1922,
-            -0.9497,
-            0.2503,
-            -0.2921,
-        ],
-        latents_std: List[float] = [
-            2.8184,
-            1.4541,
-            2.3275,
-            2.6558,
-            1.2196,
-            1.7708,
-            2.6052,
-            2.0743,
-            3.2687,
-            2.1526,
-            2.8652,
-            1.5579,
-            1.6382,
-            1.1253,
-            2.8251,
-            1.9160,
-        ],
+        latents_mean: List[float] = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921],
+        latents_std: List[float] = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160],
     ) -> None:
+        # fmt: on
         super().__init__()

         self.z_dim = z_dim
@@ -140,7 +140,7 @@ def apply_rotary_emb_qwen(

 class QwenTimestepProjEmbeddings(nn.Module):
-    def __init__(self, embedding_dim, pooled_projection_dim):
+    def __init__(self, embedding_dim):
         super().__init__()

         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
@@ -473,8 +473,6 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
         joint_attention_dim (`int`, defaults to `3584`):
             The number of dimensions to use for the joint attention (embedding/channel dimension of
             `encoder_hidden_states`).
-        pooled_projection_dim (`int`, defaults to `768`):
-            The number of dimensions to use for the pooled projection.
         guidance_embeds (`bool`, defaults to `False`):
             Whether to use guidance embeddings for guidance-distilled variant of the model.
         axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
@@ -495,8 +493,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
         attention_head_dim: int = 128,
         num_attention_heads: int = 24,
         joint_attention_dim: int = 3584,
-        pooled_projection_dim: int = 768,
-        guidance_embeds: bool = False,
+        guidance_embeds: bool = False,  # TODO: this should probably be removed
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
     ):
         super().__init__()
@@ -505,9 +502,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
         self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)

-        self.time_text_embed = QwenTimestepProjEmbeddings(
-            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
-        )
+        self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)

         self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)
@@ -538,10 +533,9 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
         timestep: torch.LongTensor = None,
         img_shapes: Optional[List[Tuple[int, int, int]]] = None,
         txt_seq_lens: Optional[List[int]] = None,
-        guidance: torch.Tensor = None,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        guidance: torch.Tensor = None,  # TODO: this should probably be removed
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-        controlnet_blocks_repeat: bool = False,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         """
         The [`QwenTransformer2DModel`] forward method.
@@ -555,7 +549,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
                 Mask of the input conditions.
             timestep ( `torch.LongTensor`):
                 Used to indicate denoising step.
-            joint_attention_kwargs (`dict`, *optional*):
+            attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -567,9 +561,9 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
             `tuple` where the first element is the sample tensor.
         """
-        if joint_attention_kwargs is not None:
-            joint_attention_kwargs = joint_attention_kwargs.copy()
-            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
         else:
             lora_scale = 1.0
@@ -577,7 +571,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
         else:
-            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                 logger.warning(
                     "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                 )
@@ -617,7 +611,7 @@ class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Fro
                 encoder_hidden_states_mask=encoder_hidden_states_mask,
                 temb=temb,
                 image_rotary_emb=image_rotary_emb,
-                joint_attention_kwargs=joint_attention_kwargs,
+                joint_attention_kwargs=attention_kwargs,
             )

         # Use only the image part (hidden_states) from the dual-stream blocks
@@ -17,19 +17,12 @@ from typing import Any, Callable, Dict, List, Optional, Union

 import numpy as np
 import torch
-from transformers import (
-    Qwen2_5_VLForConditionalGeneration,
-    Qwen2Tokenizer,
-)
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer

 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKLQwenImage, QwenImageTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
-from ...utils import (
-    is_torch_xla_available,
-    logging,
-    replace_example_docstring,
-)
+from ...utils import is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import QwenImagePipelineOutput
@@ -135,9 +128,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps


-class QwenImagePipeline(
-    DiffusionPipeline,
-):
+class QwenImagePipeline(DiffusionPipeline):
     r"""
     The QwenImage pipeline for text-to-image generation.
@@ -157,7 +148,6 @@ class QwenImagePipeline(
     """

     model_cpu_offload_seq = "text_encoder->transformer->vae"
-    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds"]

     def __init__(
@@ -186,13 +176,10 @@ class QwenImagePipeline(
         self.prompt_template_encode_start_idx = 34
         self.default_sample_size = 128

-    def extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
+    def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
         bool_mask = mask.bool()
         valid_lengths = bool_mask.sum(dim=1)
         selected = hidden_states[bool_mask]
         split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
         return split_result
@@ -200,8 +187,6 @@ class QwenImagePipeline(
     def _get_qwen_prompt_embeds(
         self,
         prompt: Union[str, List[str]] = None,
-        num_images_per_prompt: int = 1,
-        max_sequence_length: int = 1024,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
     ):
@@ -209,7 +194,6 @@ class QwenImagePipeline(
         dtype = dtype or self.text_encoder.dtype

         prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt)

         template = self.prompt_template_encode
         drop_idx = self.prompt_template_encode_start_idx
@@ -223,7 +207,7 @@ class QwenImagePipeline(
             output_hidden_states=True,
         )

         hidden_states = encoder_hidden_states.hidden_states[-1]
-        split_hidden_states = self.extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+        split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
         split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
         attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
         max_seq_len = max([e.size(0) for e in split_hidden_states])
@@ -234,18 +218,8 @@ class QwenImagePipeline(
             [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
         )

-        dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

-        _, seq_len, _ = prompt_embeds.shape
-
-        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        encoder_attention_mask = encoder_attention_mask.repeat(1, num_images_per_prompt, 1)
-        encoder_attention_mask = encoder_attention_mask.view(batch_size * num_images_per_prompt, seq_len)
-
         return prompt_embeds, encoder_attention_mask

     def encode_prompt(
@@ -253,8 +227,8 @@ class QwenImagePipeline(
         prompt: Union[str, List[str]],
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        prompt_embeds_mask: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_embeds_mask: Optional[torch.Tensor] = None,
         max_sequence_length: int = 1024,
     ):
         r"""
@@ -262,38 +236,29 @@ class QwenImagePipeline(
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
-            prompt_2 (`str` or `List[str]`, *optional*):
-                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
-                used in all text-encoders
             device: (`torch.device`):
                 torch device
             num_images_per_prompt (`int`):
                 number of images that should be generated per prompt
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
-                If not provided, pooled text embeddings will be generated from `prompt` input argument.
-            lora_scale (`float`, *optional*):
-                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
         """
         device = device or self._execution_device

         prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

         if prompt_embeds is None:
-            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
-                prompt=prompt,
-                device=device,
-                num_images_per_prompt=num_images_per_prompt,
-                max_sequence_length=max_sequence_length,
-            )
+            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)

-        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
-        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+        _, seq_len, _ = prompt_embeds.shape
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

-        return prompt_embeds, prompt_embeds_mask, text_ids
+        return prompt_embeds, prompt_embeds_mask

     def check_inputs(
         self,
@@ -457,8 +422,8 @@ class QwenImagePipeline(
         return self._guidance_scale

     @property
-    def joint_attention_kwargs(self):
-        return self._joint_attention_kwargs
+    def attention_kwargs(self):
+        return self._attention_kwargs

     @property
     def num_timesteps(self):
@@ -486,14 +451,14 @@ class QwenImagePipeline(
         guidance_scale: float = 1.0,
         num_images_per_prompt: int = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        prompt_embeds_mask: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds_mask: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        prompt_embeds_mask: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 512,
@@ -533,41 +498,23 @@ class QwenImagePipeline(
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will be generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
-                If not provided, pooled text embeddings will be generated from `prompt` input argument.
-            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
-            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
-                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
-                provided, embeddings are computed from the `ip_adapter_image` input argument.
-            negative_ip_adapter_image:
-                (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
-            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
-                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
-                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
-                provided, embeddings are computed from the `ip_adapter_image` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
-            negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-                weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
-                input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.qwenimage.QwenImagePipelineOutput`] instead of a plain tuple.
-            joint_attention_kwargs (`dict`, *optional*):
+            attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
@@ -608,7 +555,7 @@ class QwenImagePipeline(
         )

         self._guidance_scale = guidance_scale
-        self._joint_attention_kwargs = joint_attention_kwargs
+        self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
@@ -626,11 +573,7 @@ class QwenImagePipeline(
             negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
         )
         do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
-        (
-            prompt_embeds,
-            prompt_embeds_mask,
-            text_ids,
-        ) = self.encode_prompt(
+        prompt_embeds, prompt_embeds_mask = self.encode_prompt(
             prompt=prompt,
             prompt_embeds=prompt_embeds,
             prompt_embeds_mask=prompt_embeds_mask,
@@ -639,11 +582,7 @@ class QwenImagePipeline(
             max_sequence_length=max_sequence_length,
         )
         if do_true_cfg:
-            (
-                negative_prompt_embeds,
-                negative_prompt_embeds_mask,
-                negative_text_ids,
-            ) = self.encode_prompt(
+            negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(
                 prompt=negative_prompt,
                 prompt_embeds=negative_prompt_embeds,
                 prompt_embeds_mask=negative_prompt_embeds_mask,
@@ -686,8 +625,6 @@ class QwenImagePipeline(
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)

-        # print(f"timesteps: {timesteps}")
-
         # handle guidance
         if self.transformer.config.guidance_embeds:
             guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
@@ -695,8 +632,8 @@ class QwenImagePipeline(
         else:
             guidance = None

-        if self.joint_attention_kwargs is None:
-            self._joint_attention_kwargs = {}
+        if self.attention_kwargs is None:
+            self._attention_kwargs = {}

         # 6. Denoising loop
         self.scheduler.set_begin_index(0)
@@ -717,7 +654,7 @@ class QwenImagePipeline(
                     encoder_hidden_states=prompt_embeds,
                     img_shapes=img_shapes,
                     txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
-                    joint_attention_kwargs=self.joint_attention_kwargs,
+                    attention_kwargs=self.attention_kwargs,
                     return_dict=False,
                 )[0]
@@ -731,7 +668,7 @@ class QwenImagePipeline(
                         encoder_hidden_states=negative_prompt_embeds,
                         img_shapes=img_shapes,
                         txt_seq_lens=negative_prompt_embeds_mask.sum(dim=1).tolist(),
-                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        attention_kwargs=self.attention_kwargs,
                         return_dict=False,
                     )[0]

                 comb_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import torch
from transformers import Qwen2_5_VLConfig, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
from diffusers import (
AutoencoderKLQwenImage,
FlowMatchEulerDiscreteScheduler,
QwenImagePipeline,
QwenImageTransformer2DModel,
)
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
enable_full_determinism()
class QwenImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = QwenImagePipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
supports_dduf = False
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
def get_dummy_components(self):
torch.manual_seed(0)
transformer = QwenImageTransformer2DModel(
patch_size=2,
in_channels=16,
out_channels=4,
num_layers=2,
attention_head_dim=16,
num_attention_heads=3,
joint_attention_dim=16,
guidance_embeds=False,
axes_dims_rope=(8, 4, 4),
)
torch.manual_seed(0)
z_dim = 4
vae = AutoencoderKLQwenImage(
base_dim=z_dim * 6,
z_dim=z_dim,
dim_mult=[1, 2, 4],
num_res_blocks=1,
temperal_downsample=[False, True],
# fmt: off
latents_mean=[0.0] * 4,
latents_std=[1.0] * 4,
# fmt: on
)
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler()
torch.manual_seed(0)
config = Qwen2_5_VLConfig(
text_config={
"hidden_size": 16,
"intermediate_size": 16,
"num_hidden_layers": 2,
"num_attention_heads": 2,
"num_key_value_heads": 2,
"rope_scaling": {
"mrope_section": [1, 1, 2],
"rope_type": "default",
"type": "default",
},
"rope_theta": 1000000.0,
},
vision_config={
"depth": 2,
"hidden_size": 16,
"intermediate_size": 16,
"num_heads": 2,
"out_hidden_size": 16,
},
hidden_size=16,
vocab_size=152064,
vision_end_token_id=151653,
vision_start_token_id=151652,
vision_token_id=151654,
)
text_encoder = Qwen2_5_VLForConditionalGeneration(config)
tokenizer = Qwen2Tokenizer.from_pretrained("hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration")
components = {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 3.0,
"true_cfg_scale": 1.0,
"height": 32,
"width": 32,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs).images
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 32, 32))
# fmt: off
expected_slice = torch.tensor([0.563, 0.6358, 0.6028, 0.5656, 0.5806, 0.5512, 0.5712, 0.6331, 0.4147, 0.3558, 0.5625, 0.4831, 0.4957, 0.5258, 0.4075, 0.5018])
# fmt: on
generated_slice = generated_image.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-1)
def test_attention_slicing_forward_pass(
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
if test_max_difference:
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
self.assertLess(
max(max_diff1, max_diff2),
expected_max_diff,
"Attention slicing should not affect the inference results",
)
def test_vae_tiling(self, expected_diff_max: float = 0.2):
generator_device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
# Without tiling
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_without_tiling = pipe(**inputs)[0]
# With tiling
pipe.vae.enable_tiling(
tile_sample_min_height=96,
tile_sample_min_width=96,
tile_sample_stride_height=64,
tile_sample_stride_width=64,
)
inputs = self.get_dummy_inputs(generator_device)
inputs["height"] = inputs["width"] = 128
output_with_tiling = pipe(**inputs)[0]
self.assertLess(
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
expected_diff_max,
"VAE tiling should not affect the inference results",
)