Unverified commit e682af20, authored by naykun, committed by GitHub

Qwen Image Edit Support (#12164)

* feat(qwen-image):
add qwen-image-edit support

* fix(qwen image):
- compatible with torch.compile in new rope setting
- fix init import
- add prompt truncation in img2img and inpaint pipe
- remove unused logic and comment
- add copy statement
- guard logic for rope video shape tuple

* fix(qwen image):
- make fix-copies
- update doc
parent a58a4f66
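For orientation, a minimal usage sketch of the new pipeline (the checkpoint id, example URL, and argument values are illustrative assumptions, not taken from this commit):

```python
import torch
from diffusers import QwenImageEditPipeline
from diffusers.utils import load_image

# "Qwen/Qwen-Image-Edit" is assumed here for illustration
pipe = QwenImageEditPipeline.from_pretrained("Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = load_image("https://example.com/cat.png")  # placeholder input image
result = pipe(image=image, prompt="turn the cat's fur blue", num_inference_steps=50)
result.images[0].save("edited.png")
```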
@@ -492,6 +492,7 @@ else:
         "QwenImageImg2ImgPipeline",
         "QwenImageInpaintPipeline",
         "QwenImagePipeline",
+        "QwenImageEditPipeline",
         "ReduxImageEncoder",
         "SanaControlNetPipeline",
         "SanaPAGPipeline",
@@ -1123,6 +1124,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         PixArtAlphaPipeline,
         PixArtSigmaPAGPipeline,
         PixArtSigmaPipeline,
+        QwenImageEditPipeline,
         QwenImageImg2ImgPipeline,
         QwenImageInpaintPipeline,
         QwenImagePipeline,
......
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import functools
 import math
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -161,9 +160,9 @@ class QwenEmbedRope(nn.Module):
         super().__init__()
         self.theta = theta
         self.axes_dim = axes_dim
-        pos_index = torch.arange(1024)
-        neg_index = torch.arange(1024).flip(0) * -1 - 1
-        pos_freqs = torch.cat(
+        pos_index = torch.arange(4096)
+        neg_index = torch.arange(4096).flip(0) * -1 - 1
+        self.pos_freqs = torch.cat(
             [
                 self.rope_params(pos_index, self.axes_dim[0], self.theta),
                 self.rope_params(pos_index, self.axes_dim[1], self.theta),
@@ -171,7 +170,7 @@ class QwenEmbedRope(nn.Module):
             ],
             dim=1,
         )
-        neg_freqs = torch.cat(
+        self.neg_freqs = torch.cat(
             [
                 self.rope_params(neg_index, self.axes_dim[0], self.theta),
                 self.rope_params(neg_index, self.axes_dim[1], self.theta),
@@ -180,10 +179,8 @@ class QwenEmbedRope(nn.Module):
             dim=1,
         )
         self.rope_cache = {}
-        self.register_buffer("pos_freqs", pos_freqs, persistent=False)
-        self.register_buffer("neg_freqs", neg_freqs, persistent=False)
-        # whether to use scale rope
+        # DO NOT USE register_buffer HERE: IT WILL CAUSE THE COMPLEX NUMBERS TO LOSE THEIR IMAGINARY PART
         self.scale_rope = scale_rope

     def rope_params(self, index, dim, theta=10000):
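The two removed `register_buffer` calls, together with the new warning comment, guard against a subtle dtype issue: `nn.Module.to(dtype)` also casts complex-valued buffers, which silently drops the imaginary part of the RoPE frequency table. A minimal standalone sketch of that failure mode (illustrative, not part of this commit):

```python
import torch
import torch.nn as nn

class RopeBufferDemo(nn.Module):
    def __init__(self):
        super().__init__()
        # complex64 frequency table, like the RoPE freqs above
        freqs = torch.polar(torch.ones(4, 8), torch.randn(4, 8))
        self.register_buffer("freqs", freqs, persistent=False)

m = RopeBufferDemo()
print(m.freqs.dtype)  # torch.complex64
m.to(torch.bfloat16)  # the dtype cast pipelines routinely apply to the model
print(m.freqs.dtype)  # torch.bfloat16 -- the imaginary part is gone
```

Keeping the tables as plain tensor attributes exempts them from the cast; the device check added to forward() below handles device placement instead.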
@@ -201,35 +198,47 @@ class QwenEmbedRope(nn.Module):
         Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
             txt_length: [bs] a list of 1 integers representing the length of the text
         """
+        if self.pos_freqs.device != device:
+            self.pos_freqs = self.pos_freqs.to(device)
+            self.neg_freqs = self.neg_freqs.to(device)
+
         if isinstance(video_fhw, list):
             video_fhw = video_fhw[0]
-        frame, height, width = video_fhw
-        rope_key = f"{frame}_{height}_{width}"
-
-        if not torch.compiler.is_compiling():
-            if rope_key not in self.rope_cache:
-                self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width)
-            vid_freqs = self.rope_cache[rope_key]
-        else:
-            vid_freqs = self._compute_video_freqs(frame, height, width)
-
-        if self.scale_rope:
-            max_vid_index = max(height // 2, width // 2)
-        else:
-            max_vid_index = max(height, width)
-
+        if not isinstance(video_fhw, list):
+            video_fhw = [video_fhw]
+
+        vid_freqs = []
+        max_vid_index = 0
+        for idx, fhw in enumerate(video_fhw):
+            frame, height, width = fhw
+            rope_key = f"{idx}_{height}_{width}"
+
+            if not torch.compiler.is_compiling():
+                if rope_key not in self.rope_cache:
+                    self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
+                video_freq = self.rope_cache[rope_key]
+            else:
+                video_freq = self._compute_video_freqs(frame, height, width, idx)
+            vid_freqs.append(video_freq)
+
+            if self.scale_rope:
+                max_vid_index = max(height // 2, width // 2, max_vid_index)
+            else:
+                max_vid_index = max(height, width, max_vid_index)
+
         max_len = max(txt_seq_lens)
         txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
+        vid_freqs = torch.cat(vid_freqs, dim=0)

         return vid_freqs, txt_freqs

     @functools.lru_cache(maxsize=None)
-    def _compute_video_freqs(self, frame, height, width):
+    def _compute_video_freqs(self, frame, height, width, idx=0):
         seq_lens = frame * height * width
         freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
         freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)

-        freqs_frame = freqs_pos[0][:frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
+        freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
         if self.scale_rope:
             freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
             freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
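The `torch.compiler.is_compiling()` branch above exists because mutating a Python dict during torch.compile tracing is a graph-break hazard, so the rope cache is bypassed while compiling. The guard pattern, extracted as a standalone sketch (names are illustrative):

```python
import torch

_freq_cache = {}  # eager-mode memoization only

def cached_compute(key, compute):
    # Skip the dict while torch.compile traces; dict mutation during
    # tracing can break the graph or capture stale entries.
    if torch.compiler.is_compiling():
        return compute()
    if key not in _freq_cache:
        _freq_cache[key] = compute()
    return _freq_cache[key]
```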
......
@@ -391,6 +391,7 @@ else:
         "QwenImagePipeline",
         "QwenImageImg2ImgPipeline",
         "QwenImageInpaintPipeline",
+        "QwenImageEditPipeline",
     ]

 try:
     if not is_onnx_available():
@@ -708,7 +709,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         from .paint_by_example import PaintByExamplePipeline
         from .pia import PIAPipeline
         from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
-        from .qwenimage import QwenImageImg2ImgPipeline, QwenImageInpaintPipeline, QwenImagePipeline
+        from .qwenimage import (
+            QwenImageEditPipeline,
+            QwenImageImg2ImgPipeline,
+            QwenImageInpaintPipeline,
+            QwenImagePipeline,
+        )
         from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
         from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
         from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
......
@@ -26,6 +26,7 @@ else:
     _import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
     _import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
     _import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
+    _import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -35,6 +36,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
         from .pipeline_qwenimage import QwenImagePipeline
+        from .pipeline_qwenimage_edit import QwenImageEditPipeline
         from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
         from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
 else:
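For readers unfamiliar with this file's layout: diffusers registers each pipeline in an `_import_structure` dict and defers the actual import through `_LazyModule`, which is why the new pipeline is added in two places. A condensed sketch of the pattern this hunk extends (abridged to the qwen-image-edit entry only):

```python
import sys
from typing import TYPE_CHECKING

from ...utils import DIFFUSERS_SLOW_IMPORT, _LazyModule

_import_structure = {"pipeline_qwenimage_edit": ["QwenImageEditPipeline"]}

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    from .pipeline_qwenimage_edit import QwenImageEditPipeline
else:
    # the module is replaced by a lazy proxy that imports submodules on first access
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)
```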
......
@@ -253,6 +253,9 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)

+        prompt_embeds = prompt_embeds[:, :max_sequence_length]
+        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
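A note on the truncation added here (the same change lands in the img2img and inpaint pipelines below): the text encoder can return more tokens than `max_sequence_length`, and the embeddings and their mask must stay in sync because the mask's row sums later become the text lengths used for RoPE. Conceptually (shapes assumed for illustration):

```python
# prompt_embeds: (batch, seq_len, dim); prompt_embeds_mask: (batch, seq_len)
prompt_embeds = prompt_embeds[:, :max_sequence_length]
prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
txt_seq_lens = prompt_embeds_mask.sum(dim=1).tolist()  # per-sample text lengths
```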
@@ -316,20 +319,6 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

-    @staticmethod
-    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
-
-        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
-        latent_image_ids = latent_image_ids.reshape(
-            latent_image_id_height * latent_image_id_width, latent_image_id_channels
-        )
-
-        return latent_image_ids.to(device=device, dtype=dtype)
-
     @staticmethod
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
@@ -402,8 +391,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         shape = (batch_size, 1, num_channels_latents, height, width)

         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids
+            return latents.to(device=device, dtype=dtype)

         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
@@ -414,9 +402,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
-        return latents, latent_image_ids
+        return latents

     @property
     def guidance_scale(self):
@@ -594,7 +580,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
-        latents, latent_image_ids = self.prepare_latents(
+        latents = self.prepare_latents(
             batch_size * num_images_per_prompt,
             num_channels_latents,
             height,
@@ -604,7 +590,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
             generator,
             latents,
         )
-        img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
+        img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size

         # 5. Prepare timesteps
         sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
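The extra nesting in `img_shapes` changes the contract with QwenEmbedRope above: each batch entry is now a list of (frames, height, width) tuples rather than a single tuple, which also replaces the deleted Flux-style `_prepare_latent_image_ids` helper, since rotary positions are now derived inside the transformer from these shapes. A sketch of the two cases (variable names and the edit-pipeline usage are assumptions based on this diff):

```python
# text-to-image: one latent grid per sample
img_shapes = [[(1, latent_h, latent_w)]] * batch_size

# image editing (assumed): noise latents plus the VAE-encoded reference image
img_shapes = [[(1, latent_h, latent_w), (1, ref_latent_h, ref_latent_w)]] * batch_size
```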
......
This diff is collapsed.
@@ -296,6 +296,9 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)

+        prompt_embeds = prompt_embeds[:, :max_sequence_length]
+        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -363,21 +366,6 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

-    @staticmethod
-    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._prepare_latent_image_ids
-    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
-
-        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
-        latent_image_ids = latent_image_ids.reshape(
-            latent_image_id_height * latent_image_id_width, latent_image_id_channels
-        )
-
-        return latent_image_ids.to(device=device, dtype=dtype)
-
     @staticmethod
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
@@ -465,8 +453,7 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
             raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")

         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids
+            return latents.to(device=device, dtype=dtype)

         image = image.to(device=device, dtype=dtype)
         if image.shape[1] != self.latent_channels:
@@ -489,9 +476,7 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         latents = self.scheduler.scale_noise(image_latents, timestep, noise)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
-        return latents, latent_image_ids
+        return latents

     @property
     def guidance_scale(self):
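For reference, `scheduler.scale_noise` above is the flow-matching counterpart of `add_noise`: at the strength-dependent timestep it linearly blends the clean image latents toward noise. Conceptually (a simplification of FlowMatchEulerDiscreteScheduler.scale_noise):

```python
def scale_noise_conceptual(image_latents, sigma_t, noise):
    # sigma_t in [0, 1]: 0 keeps the input image, 1 is pure noise
    return sigma_t * noise + (1.0 - sigma_t) * image_latents
```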
@@ -713,7 +698,7 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         # 5. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
-        latents, latent_image_ids = self.prepare_latents(
+        latents = self.prepare_latents(
             init_image,
             latent_timestep,
             batch_size * num_images_per_prompt,
@@ -725,7 +710,7 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
             generator,
             latents,
         )
-        img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
+        img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size

         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
......
@@ -307,6 +307,9 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)

+        prompt_embeds = prompt_embeds[:, :max_sequence_length]
+        prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]
+
         _, seq_len, _ = prompt_embeds.shape
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -390,21 +393,6 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")

-    @staticmethod
-    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._prepare_latent_image_ids
-    def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height, width, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
-
-        latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
-
-        latent_image_ids = latent_image_ids.reshape(
-            latent_image_id_height * latent_image_id_width, latent_image_id_channels
-        )
-
-        return latent_image_ids.to(device=device, dtype=dtype)
-
     @staticmethod
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
@@ -492,8 +480,7 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
             raise ValueError(f"Expected image dims 4 or 5, got {image.dim()}.")

         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-            return latents.to(device=device, dtype=dtype), latent_image_ids
+            return latents.to(device=device, dtype=dtype)

         image = image.to(device=device, dtype=dtype)
         if image.shape[1] != self.latent_channels:
@@ -524,9 +511,7 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         image_latents = self._pack_latents(image_latents, batch_size, num_channels_latents, height, width)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
-
-        return latents, noise, image_latents, latent_image_ids
+        return latents, noise, image_latents

     def prepare_mask_latents(
         self,
@@ -859,7 +844,7 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
         # 5. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
-        latents, noise, image_latents, latent_image_ids = self.prepare_latents(
+        latents, noise, image_latents = self.prepare_latents(
             init_image,
             latent_timestep,
             batch_size * num_images_per_prompt,
@@ -894,7 +879,7 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
             generator,
         )
-        img_shapes = [(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)] * batch_size
+        img_shapes = [[(1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2)]] * batch_size

         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
......
@@ -1742,6 +1742,21 @@ class PixArtSigmaPipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])

+class QwenImageEditPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
 class QwenImageImg2ImgPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]