Unverified commit cc5b31ff authored by Steven Liu, committed by GitHub

[docs] Migrate syntax (#12390)

* change syntax

* make style
parent d7a1a036
@@ -324,11 +324,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
        """
        self.original_attn_processors = None
@@ -348,11 +344,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
        """
        if self.original_attn_processors is not None:
......
@@ -258,11 +258,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
        """
        self.original_attn_processors = None
@@ -282,11 +278,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
        """
        if self.original_attn_processors is not None:
......
@@ -280,11 +280,7 @@ class SD3Transformer2DModel(
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
        """
        self.original_attn_processors = None
@@ -304,11 +300,7 @@ class SD3Transformer2DModel(
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
        """
        if self.original_attn_processors is not None:
......
@@ -872,11 +872,7 @@ class UNet2DConditionModel(
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
        """
        self.original_attn_processors = None
@@ -895,11 +891,7 @@ class UNet2DConditionModel(
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
        """
        if self.original_attn_processors is not None:
......
@@ -508,11 +508,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
        """
        self.original_attn_processors = None
@@ -532,11 +528,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
        """
        if self.original_attn_processors is not None:
......
@@ -472,11 +472,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
        """
        self.original_attn_processors = None
@@ -496,11 +492,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
        """
        if self.original_attn_processors is not None:
......
@@ -1911,11 +1911,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.

-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
        """
        self.original_attn_processors = None
@@ -1935,11 +1931,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.

-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
+        > [!WARNING]
+        > This API is 🧪 experimental.
        """
        if self.original_attn_processors is not None:
......
@@ -286,11 +286,7 @@ class ComponentsManager:
    encoders, etc.) across different modular pipelines. It includes features for duplicate detection, memory
    management, and component organization.

-    <Tip warning={true}>
-
-    This is an experimental feature and is likely to change in the future.
-
-    </Tip>
+    > [!WARNING]
+    > This is an experimental feature and is likely to change in the future.

    Example:
    ```python
......
@@ -25,11 +25,7 @@ class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversion
    """
    A ModularPipeline for Flux.

-    <Tip warning={true}>
-
-    This is an experimental feature and is likely to change in the future.
-
-    </Tip>
+    > [!WARNING]
+    > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "FluxAutoBlocks"
......
@@ -226,11 +226,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
    [`ModularPipelineBlocks`] provides method to load and save the definition of pipeline blocks.

-    <Tip warning={true}>
-
-    This is an experimental feature and is likely to change in the future.
-
-    </Tip>
+    > [!WARNING]
+    > This is an experimental feature and is likely to change in the future.
    """

    config_name = "modular_config.json"
@@ -525,11 +521,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
    This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
    library implements for all the pipeline blocks (such as loading or saving etc.)

-    <Tip warning={true}>
-
-    This is an experimental feature and is likely to change in the future.
-
-    </Tip>
+    > [!WARNING]
+    > This is an experimental feature and is likely to change in the future.

    Attributes:
        block_classes: List of block classes to be used
@@ -787,11 +779,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
    This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
    library implements for all the pipeline blocks (such as loading or saving etc.)

-    <Tip warning={true}>
-
-    This is an experimental feature and is likely to change in the future.
-
-    </Tip>
+    > [!WARNING]
+    > This is an experimental feature and is likely to change in the future.

    Attributes:
        block_classes: List of block classes to be used
@@ -1146,11 +1134,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
    This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
    library implements for all the pipeline blocks (such as loading or saving etc.)

-    <Tip warning={true}>
-
-    This is an experimental feature and is likely to change in the future.
-
-    </Tip>
+    > [!WARNING]
+    > This is an experimental feature and is likely to change in the future.

    Attributes:
        block_classes: List of block classes to be used
@@ -1433,11 +1417,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
    """
    Base class for all Modular pipelines.

-    <Tip warning={true}>
-
-    This is an experimental feature and is likely to change in the future.
-
-    </Tip>
+    > [!WARNING]
+    > This is an experimental feature and is likely to change in the future.

    Args:
        blocks: ModularPipelineBlocks, the blocks to be used in the pipeline
@@ -2173,12 +2153,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
        Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
        arguments of `self.to(*args, **kwargs).`

-        <Tip>
-
-        If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise,
-        the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
-
-        </Tip>
+        > [!TIP]
+        > If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise,
+        > the returned pipeline is a copy of self with the desired torch.dtype and torch.device.

        Here are the ways to call `to`:
......
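Every hunk above applies the same mechanical substitution: the HTML-style `<Tip>` admonition used by the Hugging Face doc builder is replaced by a GitHub-flavored Markdown alert blockquote. An illustrative before/after (not taken from any single file):

```md
<!-- old: doc-builder HTML-style admonition -->
<Tip warning={true}>

This API is 🧪 experimental.

</Tip>

<!-- new: GitHub-flavored Markdown alert -->
> [!WARNING]
> This API is 🧪 experimental.
```

Note that in the alert syntax every line of the admonition body must carry the `> ` blockquote prefix, and the `[!WARNING]` marker sits alone on the first line.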
import json
import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from ..configuration_utils import ConfigMixin
from ..image_processor import PipelineImageInput
from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks
from .modular_pipeline_utils import InputParam
logger = logging.getLogger(__name__)
# YiYi Notes: this is actually for SDXL, put it here for now
SDXL_INPUTS_SCHEMA = {
"prompt": InputParam(
"prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
),
"prompt_2": InputParam(
"prompt_2",
type_hint=Union[str, List[str]],
description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
),
"negative_prompt": InputParam(
"negative_prompt",
type_hint=Union[str, List[str]],
description="The prompt or prompts not to guide the image generation",
),
"negative_prompt_2": InputParam(
"negative_prompt_2",
type_hint=Union[str, List[str]],
description="The negative prompt or prompts for text_encoder_2",
),
"cross_attention_kwargs": InputParam(
"cross_attention_kwargs",
type_hint=Optional[dict],
description="Kwargs dictionary passed to the AttentionProcessor",
),
"clip_skip": InputParam(
"clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
),
"image": InputParam(
"image",
type_hint=PipelineImageInput,
required=True,
description="The image(s) to modify for img2img or inpainting",
),
"mask_image": InputParam(
"mask_image",
type_hint=PipelineImageInput,
required=True,
description="Mask image for inpainting, white pixels will be repainted",
),
"generator": InputParam(
"generator",
type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
description="Generator(s) for deterministic generation",
),
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
"num_images_per_prompt": InputParam(
"num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
),
"timesteps": InputParam(
"timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
),
"sigmas": InputParam(
"sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
),
"denoising_end": InputParam(
"denoising_end",
type_hint=Optional[float],
description="Fraction of denoising process to complete before termination",
),
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
"strength": InputParam(
"strength", type_hint=float, default=0.3, description="How much to transform the reference image"
),
"denoising_start": InputParam(
"denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
),
"latents": InputParam(
"latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
),
"padding_mask_crop": InputParam(
"padding_mask_crop",
type_hint=Optional[Tuple[int, int]],
description="Size of margin in crop for image and mask",
),
"original_size": InputParam(
"original_size",
type_hint=Optional[Tuple[int, int]],
description="Original size of the image for SDXL's micro-conditioning",
),
"target_size": InputParam(
"target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
),
"negative_original_size": InputParam(
"negative_original_size",
type_hint=Optional[Tuple[int, int]],
description="Negative conditioning based on image resolution",
),
"negative_target_size": InputParam(
"negative_target_size",
type_hint=Optional[Tuple[int, int]],
description="Negative conditioning based on target resolution",
),
"crops_coords_top_left": InputParam(
"crops_coords_top_left",
type_hint=Tuple[int, int],
default=(0, 0),
description="Top-left coordinates for SDXL's micro-conditioning",
),
"negative_crops_coords_top_left": InputParam(
"negative_crops_coords_top_left",
type_hint=Tuple[int, int],
default=(0, 0),
description="Negative conditioning crop coordinates",
),
"aesthetic_score": InputParam(
"aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
),
"negative_aesthetic_score": InputParam(
"negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
),
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
"output_type": InputParam(
"output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
),
"ip_adapter_image": InputParam(
"ip_adapter_image",
type_hint=PipelineImageInput,
required=True,
description="Image(s) to be used as IP adapter",
),
"control_image": InputParam(
"control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
),
"control_guidance_start": InputParam(
"control_guidance_start",
type_hint=Union[float, List[float]],
default=0.0,
description="When ControlNet starts applying",
),
"control_guidance_end": InputParam(
"control_guidance_end",
type_hint=Union[float, List[float]],
default=1.0,
description="When ControlNet stops applying",
),
"controlnet_conditioning_scale": InputParam(
"controlnet_conditioning_scale",
type_hint=Union[float, List[float]],
default=1.0,
description="Scale factor for ControlNet outputs",
),
"guess_mode": InputParam(
"guess_mode",
type_hint=bool,
default=False,
description="Enables ControlNet encoder to recognize input without prompts",
),
"control_mode": InputParam(
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
),
}
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam(
"prompt_embeds",
type_hint=torch.Tensor,
required=True,
description="Text embeddings used to guide image generation",
),
"negative_prompt_embeds": InputParam(
"negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
),
"pooled_prompt_embeds": InputParam(
"pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
),
"negative_pooled_prompt_embeds": InputParam(
"negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
),
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"preprocess_kwargs": InputParam(
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
),
"latents": InputParam(
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
),
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
),
"latent_timestep": InputParam(
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
),
"image_latents": InputParam(
"image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
),
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
"masked_image_latents": InputParam(
"masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
),
"add_time_ids": InputParam(
"add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
),
"negative_add_time_ids": InputParam(
"negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
),
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"ip_adapter_embeds": InputParam(
"ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
),
"negative_ip_adapter_embeds": InputParam(
"negative_ip_adapter_embeds",
type_hint=List[torch.Tensor],
description="Negative image embeddings for IP-Adapter",
),
"images": InputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]],
required=True,
description="Generated images",
),
}
SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
DEFAULT_PARAM_MAPS = {
"prompt": {
"label": "Prompt",
"type": "string",
"default": "a bear sitting in a chair drinking a milkshake",
"display": "textarea",
},
"negative_prompt": {
"label": "Negative Prompt",
"type": "string",
"default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
"display": "textarea",
},
"num_inference_steps": {
"label": "Steps",
"type": "int",
"default": 25,
"min": 1,
"max": 1000,
},
"seed": {
"label": "Seed",
"type": "int",
"default": 0,
"min": 0,
"display": "random",
},
"width": {
"label": "Width",
"type": "int",
"display": "text",
"default": 1024,
"min": 8,
"max": 8192,
"step": 8,
"group": "dimensions",
},
"height": {
"label": "Height",
"type": "int",
"display": "text",
"default": 1024,
"min": 8,
"max": 8192,
"step": 8,
"group": "dimensions",
},
"images": {
"label": "Images",
"type": "image",
"display": "output",
},
"image": {
"label": "Image",
"type": "image",
"display": "input",
},
}
DEFAULT_TYPE_MAPS = {
"int": {
"type": "int",
"default": 0,
"min": 0,
},
"float": {
"type": "float",
"default": 0.0,
"min": 0.0,
},
"str": {
"type": "string",
"default": "",
},
"bool": {
"type": "boolean",
"default": False,
},
"image": {
"type": "image",
},
}
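The type maps above are keyed by a substring match against the lowercased, stringified type hint (this is how `ModularNode.__init__` later picks a default param spec for inputs that have no explicit entry). A minimal standalone sketch of that lookup — `infer_param` is a hypothetical helper name, not part of the file:

```python
from typing import List, Optional, Union

# Simplified copy of DEFAULT_TYPE_MAPS above.
DEFAULT_TYPE_MAPS = {
    "int": {"type": "int", "default": 0, "min": 0},
    "float": {"type": "float", "default": 0.0, "min": 0.0},
    "str": {"type": "string", "default": ""},
    "bool": {"type": "boolean", "default": False},
    "image": {"type": "image"},
}


def infer_param(name, type_hint):
    # Same substring test applied in ModularNode.__init__: e.g.
    # str(Optional[int]).lower() == "typing.optional[int]" contains "int".
    type_str = str(type_hint).lower()
    for type_key, type_param in DEFAULT_TYPE_MAPS.items():
        if type_key in type_str:
            param = type_param.copy()
            param["label"] = name
            param["display"] = "input"
            return param
    # No match: fall back to the raw name, treated as a custom param type.
    return name


print(infer_param("clip_skip", Optional[int]))
print(infer_param("prompt", Union[str, List[str]]))
```

Because the match is a plain substring test on the dict in insertion order, `Optional[int]` resolves to the `"int"` spec and `Union[str, List[str]]` to the `"string"` spec.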
DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
DEFAULT_CATEGORY = "Modular Diffusers"
DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
DEFAULT_PARAMS_GROUPS_KEYS = {
"text_encoders": ["text_encoder", "tokenizer"],
"ip_adapter_embeds": ["ip_adapter_embeds"],
"prompt_embeddings": ["prompt_embeds"],
}
def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
"""
Get the group name for a given parameter name; returns None if the name is not part of a group. e.g.
"prompt_embeds" -> "prompt_embeddings", "text_encoder" -> "text_encoders", "prompt" -> None
"""
if name is None:
return None
for group_name, group_keys in group_params_keys.items():
for group_key in group_keys:
if group_key in name:
return group_name
return None
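The grouping is a substring match: any group key contained in the parameter name claims it for that group, in dict insertion order. A standalone sketch with the function and the default groups inlined (copied from above so the snippet runs on its own):

```python
DEFAULT_PARAMS_GROUPS_KEYS = {
    "text_encoders": ["text_encoder", "tokenizer"],
    "ip_adapter_embeds": ["ip_adapter_embeds"],
    "prompt_embeddings": ["prompt_embeds"],
}


def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
    # First group whose key is a substring of `name` wins.
    if name is None:
        return None
    for group_name, group_keys in group_params_keys.items():
        for group_key in group_keys:
            if group_key in name:
                return group_name
    return None


print(get_group_name("text_encoder_2"))  # "text_encoders" ("text_encoder" is a substring)
print(get_group_name("negative_prompt_embeds"))  # "prompt_embeddings"
print(get_group_name("prompt"))  # None (no group key is a substring of "prompt")
```

The substring semantics are what let variants such as `text_encoder_2` or `negative_prompt_embeds` land in the same group as their base parameter without being listed explicitly.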
class ModularNode(ConfigMixin):
"""
A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
around a ModularPipelineBlocks object.
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
"""
config_name = "node_config.json"
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
trust_remote_code: Optional[bool] = None,
**kwargs,
):
blocks = ModularPipelineBlocks.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
return cls(blocks, **kwargs)
def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
self.blocks = blocks
if label is None:
label = self.blocks.__class__.__name__
# blocks param name -> mellon param name
self.name_mapping = {}
input_params = {}
# pass or create a default param dict for each input
# e.g. for prompt,
# prompt = {
# "name": "text_input", # the name of the input in node definition, could be different from the input name in diffusers
# "label": "Prompt",
# "type": "string",
# "default": "a bear sitting in a chair drinking a milkshake",
# "display": "textarea"}
# if type is not specified, it'll be a "custom" param of its own type
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
# it will get this spec in node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
inputs = self.blocks.inputs + self.blocks.intermediate_inputs
for inp in inputs:
param = kwargs.pop(inp.name, None)
if param:
# user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
input_params[inp.name] = param
mellon_name = param.pop("name", inp.name)
if mellon_name != inp.name:
self.name_mapping[inp.name] = mellon_name
continue
if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
continue
if inp.name in DEFAULT_PARAM_MAPS:
# first check if it's in the default param map, if so, directly use that
param = DEFAULT_PARAM_MAPS[inp.name].copy()
elif get_group_name(inp.name):
param = get_group_name(inp.name)
if inp.name not in self.name_mapping:
self.name_mapping[inp.name] = param
else:
# if not, check if it's in the SDXL input schema, if so,
# 1. use the type hint to determine the type
# 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
if inp.type_hint is not None:
type_str = str(inp.type_hint).lower()
else:
inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
for type_key, type_param in DEFAULT_TYPE_MAPS.items():
if type_key in type_str:
param = type_param.copy()
param["label"] = inp.name
param["display"] = "input"
break
else:
param = inp.name
# add the param dict to the inp_params dict
input_params[inp.name] = param
component_params = {}
for comp in self.blocks.expected_components:
param = kwargs.pop(comp.name, None)
if param:
component_params[comp.name] = param
mellon_name = param.pop("name", comp.name)
if mellon_name != comp.name:
self.name_mapping[comp.name] = mellon_name
continue
to_exclude = False
for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
if exclude_key in comp.name:
to_exclude = True
break
if to_exclude:
continue
if get_group_name(comp.name):
param = get_group_name(comp.name)
if comp.name not in self.name_mapping:
self.name_mapping[comp.name] = param
elif comp.name in DEFAULT_MODEL_KEYS:
param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
else:
param = comp.name
# add the param dict to the model_params dict
component_params[comp.name] = param
output_params = {}
if isinstance(self.blocks, SequentialPipelineBlocks):
last_block_name = list(self.blocks.sub_blocks.keys())[-1]
outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs
else:
outputs = self.blocks.intermediate_outputs
for out in outputs:
param = kwargs.pop(out.name, None)
if param:
output_params[out.name] = param
mellon_name = param.pop("name", out.name)
if mellon_name != out.name:
self.name_mapping[out.name] = mellon_name
continue
if out.name in DEFAULT_PARAM_MAPS:
param = DEFAULT_PARAM_MAPS[out.name].copy()
param["display"] = "output"
else:
group_name = get_group_name(out.name)
if group_name:
param = group_name
if out.name not in self.name_mapping:
self.name_mapping[out.name] = param
else:
param = out.name
# add the param dict to the outputs dict
output_params[out.name] = param
if len(kwargs) > 0:
logger.warning(f"Unused kwargs: {kwargs}")
register_dict = {
"category": category,
"label": label,
"input_params": input_params,
"component_params": component_params,
"output_params": output_params,
"name_mapping": self.name_mapping,
}
self.register_to_config(**register_dict)
def setup(self, components_manager, collection=None):
self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection)
self._components_manager = components_manager
@property
def mellon_config(self):
return self._convert_to_mellon_config()
def _convert_to_mellon_config(self):
node = {}
node["label"] = self.config.label
node["category"] = self.config.category
node_param = {}
for inp_name, inp_param in self.config.input_params.items():
if inp_name in self.name_mapping:
mellon_name = self.name_mapping[inp_name]
else:
mellon_name = inp_name
if isinstance(inp_param, str):
param = {
"label": inp_param,
"type": inp_param,
"display": "input",
}
else:
param = inp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
for comp_name, comp_param in self.config.component_params.items():
if comp_name in self.name_mapping:
mellon_name = self.name_mapping[comp_name]
else:
mellon_name = comp_name
if isinstance(comp_param, str):
param = {
"label": comp_param,
"type": comp_param,
"display": "input",
}
else:
param = comp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Component param {mellon_name} already exists in node_param, skipping {comp_name}")
for out_name, out_param in self.config.output_params.items():
if out_name in self.name_mapping:
mellon_name = self.name_mapping[out_name]
else:
mellon_name = out_name
if isinstance(out_param, str):
param = {
"label": out_param,
"type": out_param,
"display": "output",
}
else:
param = out_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
node["params"] = node_param
return node
def save_mellon_config(self, file_path):
"""
Save the Mellon configuration to a JSON file.
Args:
file_path (str or Path): Path where the JSON file will be saved
Returns:
Path: Path to the saved config file
"""
file_path = Path(file_path)
# Create directory if it doesn't exist
os.makedirs(file_path.parent, exist_ok=True)
# Create a combined dictionary with module definition and name mapping
config = {"module": self.mellon_config, "name_mapping": self.name_mapping}
# Save the config to file
with open(file_path, "w", encoding="utf-8") as f:
json.dump(config, f, indent=2)
logger.info(f"Mellon config and name mapping saved to {file_path}")
return file_path
@classmethod
def load_mellon_config(cls, file_path):
"""
Load a Mellon configuration from a JSON file.
Args:
file_path (str or Path): Path to the JSON file containing Mellon config
Returns:
dict: The loaded combined configuration containing 'module' and 'name_mapping'
"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"Config file not found: {file_path}")
with open(file_path, "r", encoding="utf-8") as f:
config = json.load(f)
logger.info(f"Mellon config loaded from {file_path}")
return config
def process_inputs(self, **kwargs):
params_components = {}
for comp_name, comp_param in self.config.component_params.items():
logger.debug(f"component: {comp_name}")
mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
if mellon_comp_name in kwargs:
if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
comp = kwargs[mellon_comp_name].pop(comp_name)
else:
comp = kwargs.pop(mellon_comp_name)
if comp:
params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
params_run = {}
for inp_name, inp_param in self.config.input_params.items():
logger.debug(f"input: {inp_name}")
mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
if mellon_inp_name in kwargs:
if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
inp = kwargs[mellon_inp_name].pop(inp_name)
else:
inp = kwargs.pop(mellon_inp_name)
if inp is not None:
params_run[inp_name] = inp
return_output_names = list(self.config.output_params.keys())
return params_components, params_run, return_output_names
def execute(self, **kwargs):
params_components, params_run, return_output_names = self.process_inputs(**kwargs)
self.pipeline.update_components(**params_components)
output = self.pipeline(**params_run, output=return_output_names)
return output
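The `save_mellon_config` / `load_mellon_config` pair above simply round-trips a plain `{"module": ..., "name_mapping": ...}` dict through pretty-printed JSON. A minimal standalone sketch of that round-trip (the helper functions below mirror the methods' behavior for illustration only; they are not the `ModularNode` class itself):

```python
import json
import tempfile
from pathlib import Path


def save_mellon_config(config: dict, file_path) -> Path:
    # Mirrors the method above: create the parent directory, then dump pretty-printed JSON.
    file_path = Path(file_path)
    file_path.parent.mkdir(parents=True, exist_ok=True)
    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
    return file_path


def load_mellon_config(file_path) -> dict:
    # Mirrors the classmethod above: fail loudly if the file is missing.
    file_path = Path(file_path)
    if not file_path.exists():
        raise FileNotFoundError(f"Config file not found: {file_path}")
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)


# Round-trip a config shaped like the one the class serializes (contents are made up for the demo).
config = {
    "module": {"label": "Demo node", "category": "demo", "params": {}},
    "name_mapping": {"image_out": "images"},
}
with tempfile.TemporaryDirectory() as tmp:
    path = save_mellon_config(config, Path(tmp) / "configs" / "node.json")
    loaded = load_mellon_config(path)

# JSON round-trips plain dicts of strings losslessly, so what you save is what you load.
assert loaded == config
```

Because everything in the combined dict is plain strings and nested dicts, JSON serialization is lossless here; types like tuples or sets would not survive the round-trip unchanged.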
...
@@ -97,11 +97,7 @@ class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
     """
     A ModularPipeline for QwenImage.
 
-    <Tip warning={true}>
-
-    This is an experimental feature and is likely to change in the future.
-
-    </Tip>
-
+    > [!WARNING]
+    > This is an experimental feature and is likely to change in the future.
     """
     default_blocks_name = "QwenImageAutoBlocks"
...
@@ -153,11 +149,7 @@ class QwenImageEditModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
     """
     A ModularPipeline for QwenImage-Edit.
 
-    <Tip warning={true}>
-
-    This is an experimental feature and is likely to change in the future.
-
-    </Tip>
-
+    > [!WARNING]
+    > This is an experimental feature and is likely to change in the future.
     """
     default_blocks_name = "QwenImageEditAutoBlocks"
...
...
@@ -47,11 +47,7 @@ class StableDiffusionXLModularPipeline(
     """
     A ModularPipeline for Stable Diffusion XL.
 
-    <Tip warning={true}>
-
-    This is an experimental feature and is likely to change in the future.
-
-    </Tip>
-
+    > [!WARNING]
+    > This is an experimental feature and is likely to change in the future.
     """
     default_blocks_name = "StableDiffusionXLAutoBlocks"
...
...
@@ -30,11 +30,7 @@ class WanModularPipeline(
     """
     A ModularPipeline for Wan.
 
-    <Tip warning={true}>
-
-    This is an experimental feature and is likely to change in the future.
-
-    </Tip>
-
+    > [!WARNING]
+    > This is an experimental feature and is likely to change in the future.
    """
    default_blocks_name = "WanAutoBlocks"
...
...
@@ -407,12 +407,8 @@ class AutoPipelineForText2Image(ConfigMixin):
                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.
 
-        <Tip>
-
-        To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
-        auth login`.
-
-        </Tip>
-
+        > [!TIP]
+        > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
+        > auth login`.
 
        Examples:
...
@@ -702,12 +698,8 @@ class AutoPipelineForImage2Image(ConfigMixin):
                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.
 
-        <Tip>
-
-        To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
-        auth login`.
-
-        </Tip>
-
+        > [!TIP]
+        > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
+        > auth login`.
 
        Examples:
...
@@ -1012,12 +1004,8 @@ class AutoPipelineForInpainting(ConfigMixin):
                Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
                loading `from_flax`.
 
-        <Tip>
-
-        To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
-        auth login`.
-
-        </Tip>
-
+        > [!TIP]
+        > To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
+        > auth login`.
 
        Examples:
...
...
@@ -146,16 +146,13 @@ class StableDiffusionControlNetInpaintPipeline(
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
-    <Tip>
-
-    This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
-    ([stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting))
-    as well as default text-to-image Stable Diffusion checkpoints
-    ([stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)).
-    Default text-to-image Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on
-    those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
-
-    </Tip>
-
+    > [!TIP]
+    > This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
+    > ([stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting))
+    > as well as default text-to-image Stable Diffusion checkpoints
+    > ([stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)).
+    > Default text-to-image Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned
+    > on those, such as
+    > [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
    Args:
        vae ([`AutoencoderKL`]):
...
...
@@ -394,12 +394,8 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
            jit (`bool`, defaults to `False`):
                Whether to run `pmap` versions of the generation and safety scoring functions.
 
-                <Tip warning={true}>
-
-                This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
-                future release.
-
-                </Tip>
-
+                > [!WARNING]
+                > This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
+                > future release.
 
        Examples:
...
...
@@ -1000,11 +1000,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
        are fused. For cross-attention modules, key and value projection matrices are fused.
 
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
+        > [!WARNING]
+        > This API is 🧪 experimental.
        """
        self.original_attn_processors = None
...
@@ -1021,11 +1017,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
    def unfuse_qkv_projections(self):
        """Disables the fused QKV projection if enabled.
 
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
+        > [!WARNING]
+        > This API is 🧪 experimental.
        """
        if self.original_attn_processors is not None:
...
...
@@ -150,17 +150,13 @@ class StableDiffusionControlNetPAGInpaintPipeline(
        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
 
-    <Tip>
-
-    This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
-    ([runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)) as well as
-    default text-to-image Stable Diffusion checkpoints
-    ([runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)). Default text-to-image
-    Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on those, such as
-    [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
-
-    </Tip>
-
+    > [!TIP]
+    > This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
+    > ([runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)) as well as
+    > default text-to-image Stable Diffusion checkpoints
+    > ([runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)). Default text-to-image
+    > Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on those, such as
+    > [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
...
...
@@ -158,11 +158,7 @@ def prepare_mask_and_masked_image(image, mask):
 class PaintByExamplePipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
    _last_supported_version = "0.33.1"
    r"""
-    <Tip warning={true}>
-
-    🧪 This is an experimental feature!
-
-    </Tip>
-
+    > [!WARNING]
+    > 🧪 This is an experimental feature!
    Pipeline for image-guided image inpainting using Stable Diffusion.
...
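Every hunk in this commit applies the same mechanical rewrite: an HTML-style `<Tip>` / `<Tip warning={true}>` block becomes a GitHub-style `> [!TIP]` / `> [!WARNING]` blockquote admonition. A hedged sketch of that transformation in Python (an illustration of the pattern only, not the actual script used for this PR; it ignores docstring indentation for brevity):

```python
import re

# Match a <Tip> block; the optional warning={true} attribute decides the admonition kind.
TIP_RE = re.compile(r"<Tip( warning=\{true\})?>\s*\n(.*?)\n\s*</Tip>", re.DOTALL)


def migrate_tip(match: re.Match) -> str:
    # <Tip warning={true}> -> [!WARNING], plain <Tip> -> [!TIP]
    kind = "WARNING" if match.group(1) else "TIP"
    # Collapse the body onto blockquote lines, dropping the blank padding lines.
    body = " ".join(line.strip() for line in match.group(2).splitlines() if line.strip())
    return f"> [!{kind}]\n> {body}"


def migrate(text: str) -> str:
    return TIP_RE.sub(migrate_tip, text)
```

For example, `migrate("<Tip warning={true}>\n\nThis API is experimental.\n\n</Tip>")` yields the two-line admonition `> [!WARNING]` / `> This API is experimental.`, matching the `+` lines in the hunks above. A production version would also have to re-indent the blockquote to match the surrounding docstring and re-wrap long lines, which this sketch does not attempt.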