Unverified commit 8e53cd95 authored by naykun, committed by GitHub

Qwen-Image (#12055)



* feat: qwen-image integration

* fix(qwen-image):
- remove unused logic related to ControlNet/IP-Adapter

* fix(qwen-image):
- compatibility with the attention dispatcher
- cond cache support

* fix(qwen-image):
- cond cache registry
- attention backend argument
- fix copies

* fix(qwen-image):
- remove local test

* Update src/diffusers/models/transformers/transformer_qwenimage.py

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
parent 359b605f
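
The PR wires a complete text-to-image stack into diffusers: a `QwenImagePipeline`, a `QwenImageTransformer2DModel`, and an `AutoencoderKLQwenImage`. As a quick orientation, here is a minimal usage sketch following the standard diffusers text-to-image convention; the checkpoint id `Qwen/Qwen-Image` is an assumption, not something this diff confirms.

```python
# Hedged usage sketch for the pipeline added by this PR. The checkpoint id
# "Qwen/Qwen-Image" is an assumption; substitute the officially published repo.
import torch
from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="a ceramic mug of coffee on a wooden desk, morning light",
    num_inference_steps=50,
).images[0]
image.save("qwen_image.png")
```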
@@ -174,6 +174,7 @@ else:
     "AutoencoderKLLTXVideo",
     "AutoencoderKLMagvit",
     "AutoencoderKLMochi",
+    "AutoencoderKLQwenImage",
     "AutoencoderKLTemporalDecoder",
     "AutoencoderKLWan",
     "AutoencoderOobleck",
@@ -215,6 +216,7 @@ else:
     "OmniGenTransformer2DModel",
     "PixArtTransformer2DModel",
     "PriorTransformer",
+    "QwenImageTransformer2DModel",
     "SanaControlNetModel",
     "SanaTransformer2DModel",
     "SD3ControlNetModel",
@@ -486,6 +488,7 @@ else:
     "PixArtAlphaPipeline",
     "PixArtSigmaPAGPipeline",
     "PixArtSigmaPipeline",
+    "QwenImagePipeline",
     "ReduxImageEncoder",
     "SanaControlNetPipeline",
     "SanaPAGPipeline",
@@ -832,6 +835,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     AutoencoderKLLTXVideo,
     AutoencoderKLMagvit,
     AutoencoderKLMochi,
+    AutoencoderKLQwenImage,
     AutoencoderKLTemporalDecoder,
     AutoencoderKLWan,
     AutoencoderOobleck,
@@ -873,6 +877,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     OmniGenTransformer2DModel,
     PixArtTransformer2DModel,
     PriorTransformer,
+    QwenImageTransformer2DModel,
     SanaControlNetModel,
     SanaTransformer2DModel,
     SD3ControlNetModel,
@@ -1119,6 +1124,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     PixArtAlphaPipeline,
     PixArtSigmaPAGPipeline,
     PixArtSigmaPipeline,
+    QwenImagePipeline,
     ReduxImageEncoder,
     SanaControlNetPipeline,
     SanaPAGPipeline,
...
@@ -153,6 +153,7 @@ def _register_transformer_blocks_metadata():
     )
     from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
     from ..models.transformers.transformer_mochi import MochiTransformerBlock
+    from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
     from ..models.transformers.transformer_wan import WanTransformerBlock

     # BasicTransformerBlock
@@ -255,6 +256,15 @@ def _register_transformer_blocks_metadata():
         ),
     )

+    # QwenImage
+    TransformerBlockRegistry.register(
+        model_class=QwenImageTransformerBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=1,
+            return_encoder_hidden_states_index=0,
+        ),
+    )
+

 # fmt: off
 def _skip_attention___ret___hidden_states(self, *args, **kwargs):
...
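
The metadata registered above records where each tensor sits in the block's return value: with `return_encoder_hidden_states_index=0` and `return_hidden_states_index=1`, the QwenImage block is assumed to return `(encoder_hidden_states, hidden_states)`. A small sketch of the lookup this enables; `unpack_block_outputs` is a hypothetical helper, not diffusers API.

```python
# Hedged sketch: picking tensors out of a block's tuple output using the
# registered indices. `unpack_block_outputs` is hypothetical, not diffusers API;
# the cond-cache hooks perform the equivalent lookup via TransformerBlockMetadata.
from typing import Tuple

import torch


def unpack_block_outputs(
    outputs: Tuple[torch.Tensor, ...],
    hidden_states_index: int = 1,
    encoder_hidden_states_index: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Return (hidden_states, encoder_hidden_states) from a block's output tuple."""
    return outputs[hidden_states_index], outputs[encoder_hidden_states_index]
```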
@@ -38,6 +38,7 @@ if is_torch_available():
     _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
     _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
     _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
+    _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
     _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
     _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
     _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
@@ -88,6 +89,7 @@ if is_torch_available():
     _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
     _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
+    _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
@@ -126,6 +128,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     AutoencoderKLLTXVideo,
     AutoencoderKLMagvit,
     AutoencoderKLMochi,
+    AutoencoderKLQwenImage,
     AutoencoderKLTemporalDecoder,
     AutoencoderKLWan,
     AutoencoderOobleck,
@@ -177,6 +180,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     OmniGenTransformer2DModel,
     PixArtTransformer2DModel,
     PriorTransformer,
+    QwenImageTransformer2DModel,
     SanaTransformer2DModel,
     SD3Transformer2DModel,
     SkyReelsV2Transformer3DModel,
...
@@ -8,6 +8,7 @@ from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
 from .autoencoder_kl_magvit import AutoencoderKLMagvit
 from .autoencoder_kl_mochi import AutoencoderKLMochi
+from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
 from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
 from .autoencoder_kl_wan import AutoencoderKLWan
 from .autoencoder_oobleck import AutoencoderOobleck
...
This diff is collapsed.
@@ -30,6 +30,7 @@ if is_torch_available():
     from .transformer_lumina2 import Lumina2Transformer2DModel
     from .transformer_mochi import MochiTransformer3DModel
     from .transformer_omnigen import OmniGenTransformer2DModel
+    from .transformer_qwenimage import QwenImageTransformer2DModel
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
     from .transformer_temporal import TransformerTemporalModel
...
This diff is collapsed.
@@ -387,6 +387,7 @@ else:
         "SkyReelsV2ImageToVideoPipeline",
         "SkyReelsV2Pipeline",
     ]
+    _import_structure["qwenimage"] = ["QwenImagePipeline"]
 try:
     if not is_onnx_available():
         raise OptionalDependencyNotAvailable()
@@ -703,6 +704,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .paint_by_example import PaintByExamplePipeline
     from .pia import PIAPipeline
     from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
+    from .qwenimage import QwenImagePipeline
     from .sana import SanaControlNetPipeline, SanaPipeline, SanaSprintImg2ImgPipeline, SanaSprintPipeline
     from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
     from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
...
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["QwenImagePipelineOutput"]}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .pipeline_qwenimage import QwenImagePipeline
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
    for name, value in _additional_imports.items():
        setattr(sys.modules[__name__], name, value)
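
The new `__init__.py` follows diffusers' lazy-import convention: at import time the module object is replaced by a `_LazyModule` proxy, and the torch/transformers-heavy submodules are only imported when an attribute such as `QwenImagePipeline` is first accessed. A simplified stand-in of that pattern, for illustration only (not the actual `diffusers._LazyModule`):

```python
# Simplified illustration of the lazy-import pattern above; not the real
# diffusers._LazyModule. Submodules are imported only on first attribute access.
import importlib
import types


class LazyModuleSketch(types.ModuleType):
    def __init__(self, name: str, import_structure: dict):
        super().__init__(name)
        # Map each exported attribute to the submodule that defines it.
        self._attr_to_submodule = {
            attr: sub for sub, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, name: str):
        sub = self._attr_to_submodule.get(name)
        if sub is None:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
        value = getattr(importlib.import_module(f"{self.__name__}.{sub}"), name)
        setattr(self, name, value)  # cache so later lookups bypass __getattr__
        return value
```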
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import PIL.Image

from ...utils import BaseOutput


@dataclass
class QwenImagePipelineOutput(BaseOutput):
    """
    Output class for the Qwen-Image pipeline.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size`, or a NumPy array of shape `(batch_size, height,
            width, num_channels)`, representing the denoised images of the diffusion pipeline.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
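
Like other `BaseOutput` subclasses in diffusers, the dataclass is assumed to support both attribute and key access; a brief sketch with placeholder values (the import path follows the file layout shown in this diff):

```python
# Hedged sketch of consuming the output class; the array is a placeholder.
import numpy as np
from diffusers.pipelines.qwenimage.pipeline_output import QwenImagePipelineOutput

out = QwenImagePipelineOutput(images=np.zeros((1, 64, 64, 3), dtype=np.float32))
batch = out.images           # attribute access
batch_again = out["images"]  # dict-style access provided by BaseOutput
```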
This diff is collapsed.
@@ -423,6 +423,21 @@ class AutoencoderKLMochi(metaclass=DummyObject):
         requires_backends(cls, ["torch"])


+class AutoencoderKLQwenImage(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
     _backends = ["torch"]
@@ -1038,6 +1053,21 @@ class PriorTransformer(metaclass=DummyObject):
         requires_backends(cls, ["torch"])


+class QwenImageTransformer2DModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class SanaControlNetModel(metaclass=DummyObject):
     _backends = ["torch"]
...
@@ -1742,6 +1742,21 @@ class PixArtSigmaPipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])


+class QwenImagePipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class ReduxImageEncoder(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
...
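
These dummy classes let `from diffusers import QwenImagePipeline` succeed even when torch or transformers is missing; the failure is deferred to instantiation, with an error naming the absent backends. A simplified stand-in for the `DummyObject`/`requires_backends` mechanism (not the actual diffusers implementation):

```python
# Simplified stand-in for diffusers' DummyObject/requires_backends pattern;
# the real implementation lives in diffusers.utils.
class DummyObjectSketch(type):
    def __call__(cls, *args, **kwargs):
        raise ImportError(
            f"{cls.__name__} requires the following backends: {', '.join(cls._backends)}"
        )


class QwenImagePipelineStub(metaclass=DummyObjectSketch):
    _backends = ["torch", "transformers"]


# QwenImagePipelineStub()  # would raise: ImportError naming torch and transformers
```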