Unverified Commit cc5b31ff authored by Steven Liu, committed by GitHub

[docs] Migrate syntax (#12390)

* change syntax

* make style
parent d7a1a036
......@@ -324,11 +324,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
> [!WARNING]
> This API is 🧪 experimental.
"""
self.original_attn_processors = None
......@@ -348,11 +344,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
> [!WARNING]
> This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
......
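For orientation, the API whose docs are being migrated here is a fuse/unfuse pair. A minimal usage sketch, assuming a Hunyuan-DiT checkpoint (the model ID and dtype below are illustrative, not part of this diff):

```python
import torch
from diffusers import HunyuanDiT2DModel

# Illustrative checkpoint; any model exposing the experimental API works the same way.
model = HunyuanDiT2DModel.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer", torch_dtype=torch.float16
)
model.fuse_qkv_projections()    # fuse Q/K/V (or K/V for cross-attention) projection matrices
# ... run inference while the fused attention processors are active ...
model.unfuse_qkv_projections()  # restore the original attention processors
```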
......@@ -258,11 +258,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
> [!WARNING]
> This API is 🧪 experimental.
"""
self.original_attn_processors = None
......@@ -282,11 +278,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
> [!WARNING]
> This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
......
......@@ -280,11 +280,7 @@ class SD3Transformer2DModel(
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
> [!WARNING]
> This API is 🧪 experimental.
"""
self.original_attn_processors = None
......@@ -304,11 +300,7 @@ class SD3Transformer2DModel(
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
> [!WARNING]
> This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
......
......@@ -872,11 +872,7 @@ class UNet2DConditionModel(
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
> [!WARNING]
> This API is 🧪 experimental.
"""
self.original_attn_processors = None
......@@ -895,11 +891,7 @@ class UNet2DConditionModel(
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
> [!WARNING]
> This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
......
......@@ -508,11 +508,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
> [!WARNING]
> This API is 🧪 experimental.
"""
self.original_attn_processors = None
......@@ -532,11 +528,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
> [!WARNING]
> This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
......
......@@ -472,11 +472,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
> [!WARNING]
> This API is 🧪 experimental.
"""
self.original_attn_processors = None
......@@ -496,11 +492,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
> [!WARNING]
> This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
......
......@@ -1911,11 +1911,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
> [!WARNING]
> This API is 🧪 experimental.
"""
self.original_attn_processors = None
......@@ -1935,11 +1931,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
> [!WARNING]
> This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
......
......@@ -286,11 +286,7 @@ class ComponentsManager:
encoders, etc.) across different modular pipelines. It includes features for duplicate detection, memory
management, and component organization.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
Example:
```python
......
......@@ -25,11 +25,7 @@ class FluxModularPipeline(ModularPipeline, FluxLoraLoaderMixin, TextualInversion
"""
A ModularPipeline for Flux.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "FluxAutoBlocks"
......
......@@ -226,11 +226,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
[`ModularPipelineBlocks`] provides method to load and save the definition of pipeline blocks.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
"""
config_name = "modular_config.json"
......@@ -525,11 +521,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
Attributes:
block_classes: List of block classes to be used
......@@ -787,11 +779,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
Attributes:
block_classes: List of block classes to be used
......@@ -1146,11 +1134,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
Attributes:
block_classes: List of block classes to be used
......@@ -1433,11 +1417,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
"""
Base class for all Modular pipelines.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
Args:
blocks: ModularPipelineBlocks, the blocks to be used in the pipeline
......@@ -2173,12 +2153,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
Performs Pipeline dtype and/or device conversion. A torch.dtype and torch.device are inferred from the
arguments of `self.to(*args, **kwargs).`
<Tip>
If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise,
the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
</Tip>
> [!TIP]
> If the pipeline already has the correct torch.dtype and torch.device, then it is returned as is. Otherwise,
> the returned pipeline is a copy of self with the desired torch.dtype and torch.device.
Here are the ways to call `to`:
......
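The `to` docstring above keeps the standard pipeline conversion semantics; a brief sketch of the call patterns it describes, assuming `pipe` is an already-built `ModularPipeline`:

```python
import torch

pipe = pipe.to(torch.float16)          # dtype conversion only
pipe = pipe.to("cuda")                 # device placement only
pipe = pipe.to("cuda", torch.float16)  # device and dtype inferred from the arguments
```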
import json
import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from ..configuration_utils import ConfigMixin
from ..image_processor import PipelineImageInput
from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks
from .modular_pipeline_utils import InputParam
logger = logging.getLogger(__name__)
# YiYi Notes: this is actually for SDXL, put it here for now
SDXL_INPUTS_SCHEMA = {
"prompt": InputParam(
"prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
),
"prompt_2": InputParam(
"prompt_2",
type_hint=Union[str, List[str]],
description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
),
"negative_prompt": InputParam(
"negative_prompt",
type_hint=Union[str, List[str]],
description="The prompt or prompts not to guide the image generation",
),
"negative_prompt_2": InputParam(
"negative_prompt_2",
type_hint=Union[str, List[str]],
description="The negative prompt or prompts for text_encoder_2",
),
"cross_attention_kwargs": InputParam(
"cross_attention_kwargs",
type_hint=Optional[dict],
description="Kwargs dictionary passed to the AttentionProcessor",
),
"clip_skip": InputParam(
"clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
),
"image": InputParam(
"image",
type_hint=PipelineImageInput,
required=True,
description="The image(s) to modify for img2img or inpainting",
),
"mask_image": InputParam(
"mask_image",
type_hint=PipelineImageInput,
required=True,
description="Mask image for inpainting, white pixels will be repainted",
),
"generator": InputParam(
"generator",
type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
description="Generator(s) for deterministic generation",
),
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
"num_images_per_prompt": InputParam(
"num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
),
"timesteps": InputParam(
"timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
),
"sigmas": InputParam(
"sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
),
"denoising_end": InputParam(
"denoising_end",
type_hint=Optional[float],
description="Fraction of denoising process to complete before termination",
),
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
"strength": InputParam(
"strength", type_hint=float, default=0.3, description="How much to transform the reference image"
),
"denoising_start": InputParam(
"denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
),
"latents": InputParam(
"latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
),
"padding_mask_crop": InputParam(
"padding_mask_crop",
type_hint=Optional[Tuple[int, int]],
description="Size of margin in crop for image and mask",
),
"original_size": InputParam(
"original_size",
type_hint=Optional[Tuple[int, int]],
description="Original size of the image for SDXL's micro-conditioning",
),
"target_size": InputParam(
"target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
),
"negative_original_size": InputParam(
"negative_original_size",
type_hint=Optional[Tuple[int, int]],
description="Negative conditioning based on image resolution",
),
"negative_target_size": InputParam(
"negative_target_size",
type_hint=Optional[Tuple[int, int]],
description="Negative conditioning based on target resolution",
),
"crops_coords_top_left": InputParam(
"crops_coords_top_left",
type_hint=Tuple[int, int],
default=(0, 0),
description="Top-left coordinates for SDXL's micro-conditioning",
),
"negative_crops_coords_top_left": InputParam(
"negative_crops_coords_top_left",
type_hint=Tuple[int, int],
default=(0, 0),
description="Negative conditioning crop coordinates",
),
"aesthetic_score": InputParam(
"aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
),
"negative_aesthetic_score": InputParam(
"negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
),
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
"output_type": InputParam(
"output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
),
"ip_adapter_image": InputParam(
"ip_adapter_image",
type_hint=PipelineImageInput,
required=True,
description="Image(s) to be used as IP adapter",
),
"control_image": InputParam(
"control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
),
"control_guidance_start": InputParam(
"control_guidance_start",
type_hint=Union[float, List[float]],
default=0.0,
description="When ControlNet starts applying",
),
"control_guidance_end": InputParam(
"control_guidance_end",
type_hint=Union[float, List[float]],
default=1.0,
description="When ControlNet stops applying",
),
"controlnet_conditioning_scale": InputParam(
"controlnet_conditioning_scale",
type_hint=Union[float, List[float]],
default=1.0,
description="Scale factor for ControlNet outputs",
),
"guess_mode": InputParam(
"guess_mode",
type_hint=bool,
default=False,
description="Enables ControlNet encoder to recognize input without prompts",
),
"control_mode": InputParam(
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
),
}
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam(
"prompt_embeds",
type_hint=torch.Tensor,
required=True,
description="Text embeddings used to guide image generation",
),
"negative_prompt_embeds": InputParam(
"negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
),
"pooled_prompt_embeds": InputParam(
"pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
),
"negative_pooled_prompt_embeds": InputParam(
"negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
),
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"preprocess_kwargs": InputParam(
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
),
"latents": InputParam(
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
),
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
),
"latent_timestep": InputParam(
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
),
"image_latents": InputParam(
"image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
),
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
"masked_image_latents": InputParam(
"masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
),
"add_time_ids": InputParam(
"add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
),
"negative_add_time_ids": InputParam(
"negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
),
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"ip_adapter_embeds": InputParam(
"ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
),
"negative_ip_adapter_embeds": InputParam(
"negative_ip_adapter_embeds",
type_hint=List[torch.Tensor],
description="Negative image embeddings for IP-Adapter",
),
"images": InputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]],
required=True,
description="Generated images",
),
}
SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
DEFAULT_PARAM_MAPS = {
"prompt": {
"label": "Prompt",
"type": "string",
"default": "a bear sitting in a chair drinking a milkshake",
"display": "textarea",
},
"negative_prompt": {
"label": "Negative Prompt",
"type": "string",
"default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
"display": "textarea",
},
"num_inference_steps": {
"label": "Steps",
"type": "int",
"default": 25,
"min": 1,
"max": 1000,
},
"seed": {
"label": "Seed",
"type": "int",
"default": 0,
"min": 0,
"display": "random",
},
"width": {
"label": "Width",
"type": "int",
"display": "text",
"default": 1024,
"min": 8,
"max": 8192,
"step": 8,
"group": "dimensions",
},
"height": {
"label": "Height",
"type": "int",
"display": "text",
"default": 1024,
"min": 8,
"max": 8192,
"step": 8,
"group": "dimensions",
},
"images": {
"label": "Images",
"type": "image",
"display": "output",
},
"image": {
"label": "Image",
"type": "image",
"display": "input",
},
}
DEFAULT_TYPE_MAPS = {
"int": {
"type": "int",
"default": 0,
"min": 0,
},
"float": {
"type": "float",
"default": 0.0,
"min": 0.0,
},
"str": {
"type": "string",
"default": "",
},
"bool": {
"type": "boolean",
"default": False,
},
"image": {
"type": "image",
},
}
DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
DEFAULT_CATEGORY = "Modular Diffusers"
DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
DEFAULT_PARAMS_GROUPS_KEYS = {
"text_encoders": ["text_encoder", "tokenizer"],
"ip_adapter_embeds": ["ip_adapter_embeds"],
"prompt_embeddings": ["prompt_embeds"],
}
def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
"""
Get the group name for a given parameter name; returns None if the name is not part of a group. E.g.
"prompt_embeds" -> "prompt_embeddings", "text_encoder" -> "text_encoders", "prompt" -> None.
"""
if name is None:
return None
for group_name, group_keys in group_params_keys.items():
for group_key in group_keys:
if group_key in name:
return group_name
return None
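# Illustration (not part of the original file): with DEFAULT_PARAMS_GROUPS_KEYS,
# the substring matching above yields, for example:
#   get_group_name("negative_prompt_embeds") -> "prompt_embeddings"  ("prompt_embeds" is a substring)
#   get_group_name("text_encoder_2")         -> "text_encoders"      ("text_encoder" is a substring)
#   get_group_name("prompt")                 -> None                 (no group key matches)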
class ModularNode(ConfigMixin):
"""
A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
around a ModularPipelineBlocks object.
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
"""
config_name = "node_config.json"
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
trust_remote_code: Optional[bool] = None,
**kwargs,
):
blocks = ModularPipelineBlocks.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
return cls(blocks, **kwargs)
def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
self.blocks = blocks
if label is None:
label = self.blocks.__class__.__name__
# blocks param name -> mellon param name
self.name_mapping = {}
input_params = {}
# pass or create a default param dict for each input
# e.g. for prompt,
# prompt = {
# "name": "text_input", # the name of the input in node definition, could be different from the input name in diffusers
# "label": "Prompt",
# "type": "string",
# "default": "a bear sitting in a chair drinking a milkshake",
# "display": "textarea"}
# if type is not specified, it'll be a "custom" param of its own type
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
# it will get this spec in node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
inputs = self.blocks.inputs + self.blocks.intermediate_inputs
for inp in inputs:
param = kwargs.pop(inp.name, None)
if param:
# user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
input_params[inp.name] = param
mellon_name = param.pop("name", inp.name)
if mellon_name != inp.name:
self.name_mapping[inp.name] = mellon_name
continue
if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
continue
if inp.name in DEFAULT_PARAM_MAPS:
# first check if it's in the default param map, if so, directly use that
param = DEFAULT_PARAM_MAPS[inp.name].copy()
elif get_group_name(inp.name):
param = get_group_name(inp.name)
if inp.name not in self.name_mapping:
self.name_mapping[inp.name] = param
else:
# if not, check if it's in the SDXL input schema, if so,
# 1. use the type hint to determine the type
# 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
if inp.type_hint is not None:
type_str = str(inp.type_hint).lower()
else:
inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
for type_key, type_param in DEFAULT_TYPE_MAPS.items():
if type_key in type_str:
param = type_param.copy()
param["label"] = inp.name
param["display"] = "input"
break
else:
param = inp.name
# add the param dict to the inp_params dict
input_params[inp.name] = param
component_params = {}
for comp in self.blocks.expected_components:
param = kwargs.pop(comp.name, None)
if param:
component_params[comp.name] = param
mellon_name = param.pop("name", comp.name)
if mellon_name != comp.name:
self.name_mapping[comp.name] = mellon_name
continue
to_exclude = False
for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
if exclude_key in comp.name:
to_exclude = True
break
if to_exclude:
continue
if get_group_name(comp.name):
param = get_group_name(comp.name)
if comp.name not in self.name_mapping:
self.name_mapping[comp.name] = param
elif comp.name in DEFAULT_MODEL_KEYS:
param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
else:
param = comp.name
# add the param dict to the model_params dict
component_params[comp.name] = param
output_params = {}
if isinstance(self.blocks, SequentialPipelineBlocks):
last_block_name = list(self.blocks.sub_blocks.keys())[-1]
outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs
else:
outputs = self.blocks.intermediate_outputs
for out in outputs:
param = kwargs.pop(out.name, None)
if param:
output_params[out.name] = param
mellon_name = param.pop("name", out.name)
if mellon_name != out.name:
self.name_mapping[out.name] = mellon_name
continue
if out.name in DEFAULT_PARAM_MAPS:
param = DEFAULT_PARAM_MAPS[out.name].copy()
param["display"] = "output"
else:
group_name = get_group_name(out.name)
if group_name:
param = group_name
if out.name not in self.name_mapping:
self.name_mapping[out.name] = param
else:
param = out.name
# add the param dict to the outputs dict
output_params[out.name] = param
if len(kwargs) > 0:
logger.warning(f"Unused kwargs: {kwargs}")
register_dict = {
"category": category,
"label": label,
"input_params": input_params,
"component_params": component_params,
"output_params": output_params,
"name_mapping": self.name_mapping,
}
self.register_to_config(**register_dict)
def setup(self, components_manager, collection=None):
self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection)
self._components_manager = components_manager
@property
def mellon_config(self):
return self._convert_to_mellon_config()
def _convert_to_mellon_config(self):
node = {}
node["label"] = self.config.label
node["category"] = self.config.category
node_param = {}
for inp_name, inp_param in self.config.input_params.items():
if inp_name in self.name_mapping:
mellon_name = self.name_mapping[inp_name]
else:
mellon_name = inp_name
if isinstance(inp_param, str):
param = {
"label": inp_param,
"type": inp_param,
"display": "input",
}
else:
param = inp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
for comp_name, comp_param in self.config.component_params.items():
if comp_name in self.name_mapping:
mellon_name = self.name_mapping[comp_name]
else:
mellon_name = comp_name
if isinstance(comp_param, str):
param = {
"label": comp_param,
"type": comp_param,
"display": "input",
}
else:
param = comp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")
for out_name, out_param in self.config.output_params.items():
if out_name in self.name_mapping:
mellon_name = self.name_mapping[out_name]
else:
mellon_name = out_name
if isinstance(out_param, str):
param = {
"label": out_param,
"type": out_param,
"display": "output",
}
else:
param = out_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
node["params"] = node_param
return node
def save_mellon_config(self, file_path):
"""
Save the Mellon configuration to a JSON file.
Args:
file_path (str or Path): Path where the JSON file will be saved
Returns:
Path: Path to the saved config file
"""
file_path = Path(file_path)
# Create directory if it doesn't exist
os.makedirs(file_path.parent, exist_ok=True)
# Create a combined dictionary with module definition and name mapping
config = {"module": self.mellon_config, "name_mapping": self.name_mapping}
# Save the config to file
with open(file_path, "w", encoding="utf-8") as f:
json.dump(config, f, indent=2)
logger.info(f"Mellon config and name mapping saved to {file_path}")
return file_path
@classmethod
def load_mellon_config(cls, file_path):
"""
Load a Mellon configuration from a JSON file.
Args:
file_path (str or Path): Path to the JSON file containing Mellon config
Returns:
dict: The loaded combined configuration containing 'module' and 'name_mapping'
"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"Config file not found: {file_path}")
with open(file_path, "r", encoding="utf-8") as f:
config = json.load(f)
logger.info(f"Mellon config loaded from {file_path}")
return config
def process_inputs(self, **kwargs):
params_components = {}
for comp_name, comp_param in self.config.component_params.items():
logger.debug(f"component: {comp_name}")
mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
if mellon_comp_name in kwargs:
if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
comp = kwargs[mellon_comp_name].pop(comp_name)
else:
comp = kwargs.pop(mellon_comp_name)
if comp:
params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
params_run = {}
for inp_name, inp_param in self.config.input_params.items():
logger.debug(f"input: {inp_name}")
mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
if mellon_inp_name in kwargs:
if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
inp = kwargs[mellon_inp_name].pop(inp_name)
else:
inp = kwargs.pop(mellon_inp_name)
if inp is not None:
params_run[inp_name] = inp
return_output_names = list(self.config.output_params.keys())
return params_components, params_run, return_output_names
def execute(self, **kwargs):
params_components, params_run, return_output_names = self.process_inputs(**kwargs)
self.pipeline.update_components(**params_components)
output = self.pipeline(**params_run, output=return_output_names)
return output
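Taken together, `setup`, `process_inputs`, and `execute` define the node lifecycle. A hedged end-to-end sketch built only from the methods above (the repo ID is hypothetical, and the `ComponentsManager` import path is an assumption):

```python
from diffusers.modular_pipelines import ComponentsManager

# Hypothetical Hub repo that stores serialized ModularPipelineBlocks.
node = ModularNode.from_pretrained("my-org/my-modular-blocks", trust_remote_code=True)
node.setup(components_manager=ComponentsManager())

# Keyword names follow the node's mellon_config params; execute() routes them
# through process_inputs() and runs the underlying modular pipeline.
output = node.execute(prompt="a bear sitting in a chair drinking a milkshake", num_inference_steps=25)
```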
......@@ -97,11 +97,7 @@ class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
"""
A ModularPipeline for QwenImage.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "QwenImageAutoBlocks"
......@@ -153,11 +149,7 @@ class QwenImageEditModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
"""
A ModularPipeline for QwenImage-Edit.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "QwenImageEditAutoBlocks"
......
......@@ -47,11 +47,7 @@ class StableDiffusionXLModularPipeline(
"""
A ModularPipeline for Stable Diffusion XL.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "StableDiffusionXLAutoBlocks"
......
......@@ -30,11 +30,7 @@ class WanModularPipeline(
"""
A ModularPipeline for Wan.
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "WanAutoBlocks"
......
......@@ -407,12 +407,8 @@ class AutoPipelineForText2Image(ConfigMixin):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
<Tip>
To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
auth login`.
</Tip>
> [!TIP]
> To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in
> with `hf auth login`.
Examples:
......@@ -702,12 +698,8 @@ class AutoPipelineForImage2Image(ConfigMixin):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
<Tip>
To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
auth login`.
</Tip>
> [!TIP]
> To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in
> with `hf auth login`.
Examples:
......@@ -1012,12 +1004,8 @@ class AutoPipelineForInpainting(ConfigMixin):
Load weights from a specified variant filename such as `"fp16"` or `"ema"`. This is ignored when
loading `from_flax`.
<Tip>
To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in with `hf
auth login`.
</Tip>
> [!TIP]
> To use private or [gated](https://huggingface.co/docs/hub/models-gated#gated-models) models, log-in
> with `hf auth login`.
Examples:
......
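For reference, the options these docstrings cover are plain `from_pretrained` arguments; a minimal sketch (the model ID is illustrative):

```python
import torch
from diffusers import AutoPipelineForText2Image

# variant="fp16" selects the fp16 weight files; private or gated repos
# additionally require a prior `hf auth login`.
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
)
```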
......@@ -146,16 +146,13 @@ class StableDiffusionControlNetInpaintPipeline(
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
<Tip>
This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
([stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting))
as well as default text-to-image Stable Diffusion checkpoints
([stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)).
Default text-to-image Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on
those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
</Tip>
> [!TIP]
> This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
> ([stable-diffusion-v1-5/stable-diffusion-inpainting](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-inpainting))
> as well as default text-to-image Stable Diffusion checkpoints
> ([stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5)).
> Default text-to-image Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned
> on those, such as
> [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
Args:
vae ([`AutoencoderKL`]):
......
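A short sketch of the pairing the tip above recommends, using the checkpoints it links to (device and dtype choices are illustrative):

```python
import torch
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline

# An inpaint-tuned ControlNet paired with a default text-to-image checkpoint,
# as suggested in the docstring above.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
)
```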
......@@ -394,12 +394,8 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
jit (`bool`, defaults to `False`):
Whether to run `pmap` versions of the generation and safety scoring functions.
<Tip warning={true}>
This argument exists because `__call__` is not yet end-to-end pmap-able. It will be removed in a
future release.
</Tip>
> [!WARNING]
> This argument exists because `__call__` is not yet end-to-end pmap-able. It will be
> removed in a future release.
Examples:
......
......@@ -1000,11 +1000,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
> [!WARNING]
> This API is 🧪 experimental.
"""
self.original_attn_processors = None
......@@ -1021,11 +1017,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
> [!WARNING]
> This API is 🧪 experimental.
"""
if self.original_attn_processors is not None:
......
......@@ -150,17 +150,13 @@ class StableDiffusionControlNetPAGInpaintPipeline(
- [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
<Tip>
This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
([runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)) as well as
default text-to-image Stable Diffusion checkpoints
([runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)). Default text-to-image
Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on those, such as
[lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
</Tip>
> [!TIP]
> This pipeline can be used with checkpoints that have been specifically fine-tuned for inpainting
> ([runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)) as well as
> default text-to-image Stable Diffusion checkpoints
> ([runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)). Default text-to-image
> Stable Diffusion checkpoints might be preferable for ControlNets that have been fine-tuned on those, such as
> [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
......
......@@ -158,11 +158,7 @@ def prepare_mask_and_masked_image(image, mask):
class PaintByExamplePipeline(DeprecatedPipelineMixin, DiffusionPipeline, StableDiffusionMixin):
_last_supported_version = "0.33.1"
r"""
<Tip warning={true}>
🧪 This is an experimental feature!
</Tip>
> [!WARNING]
> 🧪 This is an experimental feature!
Pipeline for image-guided image inpainting using Stable Diffusion.
......