Unverified Commit 630d27fe authored by Dhruv Nair's avatar Dhruv Nair Committed by GitHub
Browse files

[Modular] More Updates for Custom Code Loading (#11969)



* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

* update

---------
Co-authored-by: default avatarYiYi Xu <yixu310@gmail.com>
parent f442955c
...@@ -25,7 +25,6 @@ else: ...@@ -25,7 +25,6 @@ else:
_import_structure["modular_pipeline"] = [ _import_structure["modular_pipeline"] = [
"ModularPipelineBlocks", "ModularPipelineBlocks",
"ModularPipeline", "ModularPipeline",
"PipelineBlock",
"AutoPipelineBlocks", "AutoPipelineBlocks",
"SequentialPipelineBlocks", "SequentialPipelineBlocks",
"LoopSequentialPipelineBlocks", "LoopSequentialPipelineBlocks",
...@@ -59,7 +58,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ...@@ -59,7 +58,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LoopSequentialPipelineBlocks, LoopSequentialPipelineBlocks,
ModularPipeline, ModularPipeline,
ModularPipelineBlocks, ModularPipelineBlocks,
PipelineBlock,
PipelineState, PipelineState,
SequentialPipelineBlocks, SequentialPipelineBlocks,
) )
......
...@@ -22,7 +22,7 @@ from ...models import AutoencoderKL ...@@ -22,7 +22,7 @@ from ...models import AutoencoderKL
from ...schedulers import FlowMatchEulerDiscreteScheduler from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging from ...utils import logging
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import PipelineBlock, PipelineState from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline from .modular_pipeline import FluxModularPipeline
...@@ -231,7 +231,7 @@ def _get_initial_timesteps_and_optionals( ...@@ -231,7 +231,7 @@ def _get_initial_timesteps_and_optionals(
return timesteps, num_inference_steps, sigmas, guidance return timesteps, num_inference_steps, sigmas, guidance
class FluxInputStep(PipelineBlock): class FluxInputStep(ModularPipelineBlocks):
model_name = "flux" model_name = "flux"
@property @property
...@@ -249,11 +249,6 @@ class FluxInputStep(PipelineBlock): ...@@ -249,11 +249,6 @@ class FluxInputStep(PipelineBlock):
def inputs(self) -> List[InputParam]: def inputs(self) -> List[InputParam]:
return [ return [
InputParam("num_images_per_prompt", default=1), InputParam("num_images_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam( InputParam(
"prompt_embeds", "prompt_embeds",
required=True, required=True,
...@@ -322,7 +317,7 @@ class FluxInputStep(PipelineBlock): ...@@ -322,7 +317,7 @@ class FluxInputStep(PipelineBlock):
return components, state return components, state
class FluxSetTimestepsStep(PipelineBlock): class FluxSetTimestepsStep(ModularPipelineBlocks):
model_name = "flux" model_name = "flux"
@property @property
...@@ -340,14 +335,10 @@ class FluxSetTimestepsStep(PipelineBlock): ...@@ -340,14 +335,10 @@ class FluxSetTimestepsStep(PipelineBlock):
InputParam("timesteps"), InputParam("timesteps"),
InputParam("sigmas"), InputParam("sigmas"),
InputParam("guidance_scale", default=3.5), InputParam("guidance_scale", default=3.5),
InputParam("latents", type_hint=torch.Tensor),
InputParam("num_images_per_prompt", default=1), InputParam("num_images_per_prompt", default=1),
InputParam("height", type_hint=int), InputParam("height", type_hint=int),
InputParam("width", type_hint=int), InputParam("width", type_hint=int),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam( InputParam(
"batch_size", "batch_size",
required=True, required=True,
...@@ -398,7 +389,7 @@ class FluxSetTimestepsStep(PipelineBlock): ...@@ -398,7 +389,7 @@ class FluxSetTimestepsStep(PipelineBlock):
return components, state return components, state
class FluxImg2ImgSetTimestepsStep(PipelineBlock): class FluxImg2ImgSetTimestepsStep(ModularPipelineBlocks):
model_name = "flux" model_name = "flux"
@property @property
...@@ -420,11 +411,6 @@ class FluxImg2ImgSetTimestepsStep(PipelineBlock): ...@@ -420,11 +411,6 @@ class FluxImg2ImgSetTimestepsStep(PipelineBlock):
InputParam("num_images_per_prompt", default=1), InputParam("num_images_per_prompt", default=1),
InputParam("height", type_hint=int), InputParam("height", type_hint=int),
InputParam("width", type_hint=int), InputParam("width", type_hint=int),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam( InputParam(
"batch_size", "batch_size",
required=True, required=True,
...@@ -497,7 +483,7 @@ class FluxImg2ImgSetTimestepsStep(PipelineBlock): ...@@ -497,7 +483,7 @@ class FluxImg2ImgSetTimestepsStep(PipelineBlock):
return components, state return components, state
class FluxPrepareLatentsStep(PipelineBlock): class FluxPrepareLatentsStep(ModularPipelineBlocks):
model_name = "flux" model_name = "flux"
@property @property
...@@ -515,11 +501,6 @@ class FluxPrepareLatentsStep(PipelineBlock): ...@@ -515,11 +501,6 @@ class FluxPrepareLatentsStep(PipelineBlock):
InputParam("width", type_hint=int), InputParam("width", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]), InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_images_per_prompt", type_hint=int, default=1), InputParam("num_images_per_prompt", type_hint=int, default=1),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"), InputParam("generator"),
InputParam( InputParam(
"batch_size", "batch_size",
...@@ -621,7 +602,7 @@ class FluxPrepareLatentsStep(PipelineBlock): ...@@ -621,7 +602,7 @@ class FluxPrepareLatentsStep(PipelineBlock):
return components, state return components, state
class FluxImg2ImgPrepareLatentsStep(PipelineBlock): class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
model_name = "flux" model_name = "flux"
@property @property
...@@ -639,11 +620,6 @@ class FluxImg2ImgPrepareLatentsStep(PipelineBlock): ...@@ -639,11 +620,6 @@ class FluxImg2ImgPrepareLatentsStep(PipelineBlock):
InputParam("width", type_hint=int), InputParam("width", type_hint=int),
InputParam("latents", type_hint=Optional[torch.Tensor]), InputParam("latents", type_hint=Optional[torch.Tensor]),
InputParam("num_images_per_prompt", type_hint=int, default=1), InputParam("num_images_per_prompt", type_hint=int, default=1),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"), InputParam("generator"),
InputParam( InputParam(
"image_latents", "image_latents",
......
...@@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict ...@@ -22,7 +22,7 @@ from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL from ...models import AutoencoderKL
from ...utils import logging from ...utils import logging
from ...video_processor import VaeImageProcessor from ...video_processor import VaeImageProcessor
from ..modular_pipeline import PipelineBlock, PipelineState from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
...@@ -45,7 +45,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor): ...@@ -45,7 +45,7 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
return latents return latents
class FluxDecodeStep(PipelineBlock): class FluxDecodeStep(ModularPipelineBlocks):
model_name = "flux" model_name = "flux"
@property @property
...@@ -70,17 +70,12 @@ class FluxDecodeStep(PipelineBlock): ...@@ -70,17 +70,12 @@ class FluxDecodeStep(PipelineBlock):
InputParam("output_type", default="pil"), InputParam("output_type", default="pil"),
InputParam("height", default=1024), InputParam("height", default=1024),
InputParam("width", default=1024), InputParam("width", default=1024),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam( InputParam(
"latents", "latents",
required=True, required=True,
type_hint=torch.Tensor, type_hint=torch.Tensor,
description="The denoised latents from the denoising step", description="The denoised latents from the denoising step",
) ),
] ]
@property @property
......
...@@ -22,7 +22,7 @@ from ...utils import logging ...@@ -22,7 +22,7 @@ from ...utils import logging
from ..modular_pipeline import ( from ..modular_pipeline import (
BlockState, BlockState,
LoopSequentialPipelineBlocks, LoopSequentialPipelineBlocks,
PipelineBlock, ModularPipelineBlocks,
PipelineState, PipelineState,
) )
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
...@@ -32,7 +32,7 @@ from .modular_pipeline import FluxModularPipeline ...@@ -32,7 +32,7 @@ from .modular_pipeline import FluxModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FluxLoopDenoiser(PipelineBlock): class FluxLoopDenoiser(ModularPipelineBlocks):
model_name = "flux" model_name = "flux"
@property @property
...@@ -49,11 +49,8 @@ class FluxLoopDenoiser(PipelineBlock): ...@@ -49,11 +49,8 @@ class FluxLoopDenoiser(PipelineBlock):
@property @property
def inputs(self) -> List[Tuple[str, Any]]: def inputs(self) -> List[Tuple[str, Any]]:
return [InputParam("joint_attention_kwargs")]
@property
def intermediate_inputs(self) -> List[str]:
return [ return [
InputParam("joint_attention_kwargs"),
InputParam( InputParam(
"latents", "latents",
required=True, required=True,
...@@ -113,7 +110,7 @@ class FluxLoopDenoiser(PipelineBlock): ...@@ -113,7 +110,7 @@ class FluxLoopDenoiser(PipelineBlock):
return components, block_state return components, block_state
class FluxLoopAfterDenoiser(PipelineBlock): class FluxLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "flux" model_name = "flux"
@property @property
...@@ -175,7 +172,7 @@ class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks): ...@@ -175,7 +172,7 @@ class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
] ]
@property @property
def loop_intermediate_inputs(self) -> List[InputParam]: def loop_inputs(self) -> List[InputParam]:
return [ return [
InputParam( InputParam(
"timesteps", "timesteps",
......
...@@ -24,7 +24,7 @@ from ...image_processor import VaeImageProcessor ...@@ -24,7 +24,7 @@ from ...image_processor import VaeImageProcessor
from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL from ...models import AutoencoderKL
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
from ..modular_pipeline import PipelineBlock, PipelineState from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline from .modular_pipeline import FluxModularPipeline
...@@ -67,7 +67,7 @@ def retrieve_latents( ...@@ -67,7 +67,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output") raise AttributeError("Could not access latents of provided encoder_output")
class FluxVaeEncoderStep(PipelineBlock): class FluxVaeEncoderStep(ModularPipelineBlocks):
model_name = "flux" model_name = "flux"
@property @property
...@@ -88,11 +88,10 @@ class FluxVaeEncoderStep(PipelineBlock): ...@@ -88,11 +88,10 @@ class FluxVaeEncoderStep(PipelineBlock):
@property @property
def inputs(self) -> List[InputParam]: def inputs(self) -> List[InputParam]:
return [InputParam("image", required=True), InputParam("height"), InputParam("width")]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [ return [
InputParam("image", required=True),
InputParam("height"),
InputParam("width"),
InputParam("generator"), InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
InputParam( InputParam(
...@@ -157,7 +156,7 @@ class FluxVaeEncoderStep(PipelineBlock): ...@@ -157,7 +156,7 @@ class FluxVaeEncoderStep(PipelineBlock):
return components, state return components, state
class FluxTextEncoderStep(PipelineBlock): class FluxTextEncoderStep(ModularPipelineBlocks):
model_name = "flux" model_name = "flux"
@property @property
......
...@@ -29,11 +29,7 @@ from typing_extensions import Self ...@@ -29,11 +29,7 @@ from typing_extensions import Self
from ..configuration_utils import ConfigMixin, FrozenDict from ..configuration_utils import ConfigMixin, FrozenDict
from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj
from ..utils import ( from ..utils import PushToHubMixin, is_accelerate_available, logging
PushToHubMixin,
is_accelerate_available,
logging,
)
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ..utils.hub_utils import load_or_create_model_card, populate_model_card from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from .components_manager import ComponentsManager from .components_manager import ComponentsManager
...@@ -45,8 +41,6 @@ from .modular_pipeline_utils import ( ...@@ -45,8 +41,6 @@ from .modular_pipeline_utils import (
OutputParam, OutputParam,
format_components, format_components,
format_configs, format_configs,
format_inputs_short,
format_intermediates_short,
make_doc_string, make_doc_string,
) )
...@@ -80,139 +74,59 @@ class PipelineState: ...@@ -80,139 +74,59 @@ class PipelineState:
[`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks. [`PipelineState`] stores the state of a pipeline. It is used to pass data between pipeline blocks.
""" """
inputs: Dict[str, Any] = field(default_factory=dict) values: Dict[str, Any] = field(default_factory=dict)
intermediates: Dict[str, Any] = field(default_factory=dict) kwargs_mapping: Dict[str, List[str]] = field(default_factory=dict)
input_kwargs: Dict[str, List[str]] = field(default_factory=dict)
intermediate_kwargs: Dict[str, List[str]] = field(default_factory=dict)
def set_input(self, key: str, value: Any, kwargs_type: str = None): def set(self, key: str, value: Any, kwargs_type: str = None):
""" """
Add an input to the immutable pipeline state, i.e, pipeline_state.inputs. Add a value to the pipeline state.
The kwargs_type parameter allows you to associate inputs with specific input types. For example, if you call
set_input(prompt_embeds=..., kwargs_type="guider_kwargs"), this input will be automatically fetched when a
pipeline block has "guider_kwargs" in its expected_inputs list.
Args: Args:
key (str): The key for the input key (str): The key for the value
value (Any): The input value value (Any): The value to store
kwargs_type (str): The kwargs_type with which the input is associated kwargs_type (str): The kwargs_type with which the value is associated
"""
self.inputs[key] = value
if kwargs_type is not None:
if kwargs_type not in self.input_kwargs:
self.input_kwargs[kwargs_type] = [key]
else:
self.input_kwargs[kwargs_type].append(key)
def set_intermediate(self, key: str, value: Any, kwargs_type: str = None):
""" """
Add an intermediate value to the mutable pipeline state, i.e, pipeline_state.intermediates. self.values[key] = value
The kwargs_type parameter allows you to associate intermediate values with specific input types. For example,
if you call set_intermediate(latents=..., kwargs_type="latents_kwargs"), this intermediate value will be
automatically fetched when a pipeline block has "latents_kwargs" in its expected_intermediate_inputs list.
Args:
key (str): The key for the intermediate value
value (Any): The intermediate value
kwargs_type (str): The kwargs_type with which the intermediate value is associated
"""
self.intermediates[key] = value
if kwargs_type is not None: if kwargs_type is not None:
if kwargs_type not in self.intermediate_kwargs: if kwargs_type not in self.kwargs_mapping:
self.intermediate_kwargs[kwargs_type] = [key] self.kwargs_mapping[kwargs_type] = [key]
else: else:
self.intermediate_kwargs[kwargs_type].append(key) self.kwargs_mapping[kwargs_type].append(key)
def get_input(self, key: str, default: Any = None) -> Any: def get(self, keys: Union[str, List[str]], default: Any = None) -> Union[Any, Dict[str, Any]]:
""" """
Get an input from the pipeline state. Get one or multiple values from the pipeline state.
Args: Args:
key (str): The key for the input keys (Union[str, List[str]]): Key or list of keys for the values
default (Any): The default value to return if the input is not found default (Any): The default value to return if not found
Returns: Returns:
Any: The input value Union[Any, Dict[str, Any]]: Single value if keys is str, dictionary of values if keys is list
""" """
value = self.inputs.get(key, default) if isinstance(keys, str):
if value is not None: return self.values.get(keys, default)
return deepcopy(value) return {key: self.values.get(key, default) for key in keys}
def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: def get_by_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
""" """
Get multiple inputs from the pipeline state. Get all values with matching kwargs_type.
Args:
keys (List[str]): The keys for the inputs
default (Any): The default value to return if the input is not found
Returns:
Dict[str, Any]: Dictionary of inputs with matching keys
"""
return {key: self.inputs.get(key, default) for key in keys}
def get_inputs_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
"""
Get all inputs with matching kwargs_type.
Args: Args:
kwargs_type (str): The kwargs_type to filter by kwargs_type (str): The kwargs_type to filter by
Returns: Returns:
Dict[str, Any]: Dictionary of inputs with matching kwargs_type Dict[str, Any]: Dictionary of values with matching kwargs_type
"""
input_names = self.input_kwargs.get(kwargs_type, [])
return self.get_inputs(input_names)
def get_intermediate_kwargs(self, kwargs_type: str) -> Dict[str, Any]:
""" """
Get all intermediates with matching kwargs_type. value_names = self.kwargs_mapping.get(kwargs_type, [])
return self.get(value_names)
Args:
kwargs_type (str): The kwargs_type to filter by
Returns:
Dict[str, Any]: Dictionary of intermediates with matching kwargs_type
"""
intermediate_names = self.intermediate_kwargs.get(kwargs_type, [])
return self.get_intermediates(intermediate_names)
def get_intermediate(self, key: str, default: Any = None) -> Any:
"""
Get an intermediate value from the pipeline state.
Args:
key (str): The key for the intermediate value
default (Any): The default value to return if the intermediate value is not found
Returns:
Any: The intermediate value
"""
return self.intermediates.get(key, default)
def get_intermediates(self, keys: List[str], default: Any = None) -> Dict[str, Any]:
"""
Get multiple intermediate values from the pipeline state.
Args:
keys (List[str]): The keys for the intermediate values
default (Any): The default value to return if the intermediate value is not found
Returns:
Dict[str, Any]: Dictionary of intermediate values with matching keys
"""
return {key: self.intermediates.get(key, default) for key in keys}
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
""" """
Convert PipelineState to a dictionary. Convert PipelineState to a dictionary.
Returns:
Dict[str, Any]: Dictionary containing all attributes of the PipelineState
""" """
return {**self.__dict__, "inputs": self.inputs, "intermediates": self.intermediates} return {**self.__dict__}
def __repr__(self): def __repr__(self):
def format_value(v): def format_value(v):
...@@ -223,21 +137,10 @@ class PipelineState: ...@@ -223,21 +137,10 @@ class PipelineState:
else: else:
return repr(v) return repr(v)
inputs = "\n".join(f" {k}: {format_value(v)}" for k, v in self.inputs.items()) values_str = "\n".join(f" {k}: {format_value(v)}" for k, v in self.values.items())
intermediates = "\n".join(f" {k}: {format_value(v)}" for k, v in self.intermediates.items()) kwargs_mapping_str = "\n".join(f" {k}: {v}" for k, v in self.kwargs_mapping.items())
# Format input_kwargs and intermediate_kwargs return f"PipelineState(\n values={{\n{values_str}\n }},\n kwargs_mapping={{\n{kwargs_mapping_str}\n }}\n)"
input_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.input_kwargs.items())
intermediate_kwargs_str = "\n".join(f" {k}: {v}" for k, v in self.intermediate_kwargs.items())
return (
f"PipelineState(\n"
f" inputs={{\n{inputs}\n }},\n"
f" intermediates={{\n{intermediates}\n }},\n"
f" input_kwargs={{\n{input_kwargs_str}\n }},\n"
f" intermediate_kwargs={{\n{intermediate_kwargs_str}\n }}\n"
f")"
)
@dataclass @dataclass
...@@ -326,7 +229,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): ...@@ -326,7 +229,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
</Tip> </Tip>
""" """
config_name = "config.json" config_name = "modular_config.json"
model_name = None model_name = None
@classmethod @classmethod
...@@ -338,6 +241,14 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): ...@@ -338,6 +241,14 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
return expected_modules, optional_parameters return expected_modules, optional_parameters
def __init__(self):
self.sub_blocks = InsertableDict()
@property
def description(self) -> str:
"""Description of the block. Must be implemented by subclasses."""
return ""
@property @property
def expected_components(self) -> List[ComponentSpec]: def expected_components(self) -> List[ComponentSpec]:
return [] return []
...@@ -346,6 +257,35 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): ...@@ -346,6 +257,35 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
def expected_configs(self) -> List[ConfigSpec]: def expected_configs(self) -> List[ConfigSpec]:
return [] return []
@property
def inputs(self) -> List[InputParam]:
"""List of input parameters. Must be implemented by subclasses."""
return []
def _get_required_inputs(self):
input_names = []
for input_param in self.inputs:
if input_param.required:
input_names.append(input_param.name)
return input_names
@property
def required_inputs(self) -> List[InputParam]:
return self._get_required_inputs()
@property
def intermediate_outputs(self) -> List[OutputParam]:
"""List of intermediate output parameters. Must be implemented by subclasses."""
return []
def _get_outputs(self):
return self.intermediate_outputs
@property
def outputs(self) -> List[OutputParam]:
return self._get_outputs()
@classmethod @classmethod
def from_pretrained( def from_pretrained(
cls, cls,
...@@ -427,6 +367,63 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): ...@@ -427,6 +367,63 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
) )
return modular_pipeline return modular_pipeline
def get_block_state(self, state: PipelineState) -> dict:
"""Get all inputs and intermediates in one dictionary"""
data = {}
state_inputs = self.inputs
# Check inputs
for input_param in state_inputs:
if input_param.name:
value = state.get(input_param.name)
if input_param.required and value is None:
raise ValueError(f"Required input '{input_param.name}' is missing")
elif value is not None or (value is None and input_param.name not in data):
data[input_param.name] = value
elif input_param.kwargs_type:
# if kwargs_type is provided, get all inputs with matching kwargs_type
if input_param.kwargs_type not in data:
data[input_param.kwargs_type] = {}
inputs_kwargs = state.get_by_kwargs(input_param.kwargs_type)
if inputs_kwargs:
for k, v in inputs_kwargs.items():
if v is not None:
data[k] = v
data[input_param.kwargs_type][k] = v
return BlockState(**data)
def set_block_state(self, state: PipelineState, block_state: BlockState):
for output_param in self.intermediate_outputs:
if not hasattr(block_state, output_param.name):
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
param = getattr(block_state, output_param.name)
state.set(output_param.name, param, output_param.kwargs_type)
for input_param in self.inputs:
if input_param.name and hasattr(block_state, input_param.name):
param = getattr(block_state, input_param.name)
# Only add if the value is different from what's in the state
current_value = state.get(input_param.name)
if current_value is not param: # Using identity comparison to check if object was modified
state.set(input_param.name, param, input_param.kwargs_type)
elif input_param.kwargs_type:
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
# we need to first find out which inputs are and loop through them.
intermediate_kwargs = state.get_by_kwargs(input_param.kwargs_type)
for param_name, current_value in intermediate_kwargs.items():
if param_name is None:
continue
if not hasattr(block_state, param_name):
continue
param = getattr(block_state, param_name)
if current_value is not param: # Using identity comparison to check if object was modified
state.set(param_name, param, input_param.kwargs_type)
@staticmethod @staticmethod
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
""" """
...@@ -497,10 +494,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): ...@@ -497,10 +494,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
def input_names(self) -> List[str]: def input_names(self) -> List[str]:
return [input_param.name for input_param in self.inputs] return [input_param.name for input_param in self.inputs]
@property
def intermediate_input_names(self) -> List[str]:
return [input_param.name for input_param in self.intermediate_inputs]
@property @property
def intermediate_output_names(self) -> List[str]: def intermediate_output_names(self) -> List[str]:
return [output_param.name for output_param in self.intermediate_outputs] return [output_param.name for output_param in self.intermediate_outputs]
...@@ -509,162 +502,10 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): ...@@ -509,162 +502,10 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
def output_names(self) -> List[str]: def output_names(self) -> List[str]:
return [output_param.name for output_param in self.outputs] return [output_param.name for output_param in self.outputs]
class PipelineBlock(ModularPipelineBlocks):
"""
A Pipeline Block is the basic building block of a Modular Pipeline.
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
<Tip warning={true}>
This is an experimental feature and is likely to change in the future.
</Tip>
Args:
description (str, optional): A description of the block, defaults to None. Define as a property in subclasses.
expected_components (List[ComponentSpec], optional):
A list of components that are expected to be used in the block, defaults to []. To override, define as a
property in subclasses.
expected_configs (List[ConfigSpec], optional):
A list of configs that are expected to be used in the block, defaults to []. To override, define as a
property in subclasses.
inputs (List[InputParam], optional):
A list of inputs that are expected to be used in the block, defaults to []. To override, define as a
property in subclasses.
intermediate_inputs (List[InputParam], optional):
A list of intermediate inputs that are expected to be used in the block, defaults to []. To override,
define as a property in subclasses.
intermediate_outputs (List[OutputParam], optional):
A list of intermediate outputs that are expected to be used in the block, defaults to []. To override,
define as a property in subclasses.
outputs (List[OutputParam], optional):
A list of outputs that are expected to be used in the block, defaults to []. To override, define as a
property in subclasses.
required_inputs (List[str], optional):
A list of required inputs that are expected to be used in the block, defaults to []. To override, define as
a property in subclasses.
required_intermediate_inputs (List[str], optional):
A list of required intermediate inputs that are expected to be used in the block, defaults to []. To
override, define as a property in subclasses.
required_intermediate_outputs (List[str], optional):
A list of required intermediate outputs that are expected to be used in the block, defaults to []. To
override, define as a property in subclasses.
"""
model_name = None
def __init__(self):
self.sub_blocks = InsertableDict()
@property
def description(self) -> str:
"""Description of the block. Must be implemented by subclasses."""
# raise NotImplementedError("description method must be implemented in subclasses")
return "TODO: add a description"
@property
def expected_components(self) -> List[ComponentSpec]:
return []
@property
def expected_configs(self) -> List[ConfigSpec]:
return []
@property
def inputs(self) -> List[InputParam]:
"""List of input parameters. Must be implemented by subclasses."""
return []
@property
def intermediate_inputs(self) -> List[InputParam]:
"""List of intermediate input parameters. Must be implemented by subclasses."""
return []
@property
def intermediate_outputs(self) -> List[OutputParam]:
"""List of intermediate output parameters. Must be implemented by subclasses."""
return []
def _get_outputs(self):
return self.intermediate_outputs
# YiYi TODO: is it too easy for user to unintentionally override these properties?
# Adding outputs attributes here for consistency between PipelineBlock/AutoPipelineBlocks/SequentialPipelineBlocks
@property
def outputs(self) -> List[OutputParam]:
return self._get_outputs()
def _get_required_inputs(self):
input_names = []
for input_param in self.inputs:
if input_param.required:
input_names.append(input_param.name)
return input_names
@property
def required_inputs(self) -> List[str]:
return self._get_required_inputs()
def _get_required_intermediate_inputs(self):
input_names = []
for input_param in self.intermediate_inputs:
if input_param.required:
input_names.append(input_param.name)
return input_names
# YiYi TODO: maybe we do not need this, it is only used in docstring,
# intermediate_inputs is by default required, unless you manually handle it inside the block
@property
def required_intermediate_inputs(self) -> List[str]:
return self._get_required_intermediate_inputs()
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
raise NotImplementedError("__call__ method must be implemented in subclasses")
def __repr__(self):
class_name = self.__class__.__name__
base_class = self.__class__.__bases__[0].__name__
# Format description with proper indentation
desc_lines = self.description.split("\n")
desc = []
# First line with "Description:" label
desc.append(f" Description: {desc_lines[0]}")
# Subsequent lines with proper indentation
if len(desc_lines) > 1:
desc.extend(f" {line}" for line in desc_lines[1:])
desc = "\n".join(desc) + "\n"
# Components section - use format_components with add_empty_lines=False
expected_components = getattr(self, "expected_components", [])
components_str = format_components(expected_components, indent_level=2, add_empty_lines=False)
components = " " + components_str.replace("\n", "\n ")
# Configs section - use format_configs with add_empty_lines=False
expected_configs = getattr(self, "expected_configs", [])
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
configs = " " + configs_str.replace("\n", "\n ")
# Inputs section
inputs_str = format_inputs_short(self.inputs)
inputs = "Inputs:\n " + inputs_str
# Intermediates section
intermediates_str = format_intermediates_short(
self.intermediate_inputs, self.required_intermediate_inputs, self.intermediate_outputs
)
intermediates = f"Intermediates:\n{intermediates_str}"
return f"{class_name}(\n Class: {base_class}\n{desc}{components}\n{configs}\n {inputs}\n {intermediates}\n)"
@property @property
def doc(self): def doc(self):
return make_doc_string( return make_doc_string(
self.inputs, self.inputs,
self.intermediate_inputs,
self.outputs, self.outputs,
self.description, self.description,
class_name=self.__class__.__name__, class_name=self.__class__.__name__,
...@@ -672,82 +513,6 @@ class PipelineBlock(ModularPipelineBlocks): ...@@ -672,82 +513,6 @@ class PipelineBlock(ModularPipelineBlocks):
expected_configs=self.expected_configs, expected_configs=self.expected_configs,
) )
# YiYi TODO: input and intermediate inputs with same name? should warn?
def get_block_state(self, state: PipelineState) -> dict:
    """Resolve this block's declared inputs and intermediate inputs from the pipeline state.

    Named params are looked up individually; params declared only via ``kwargs_type``
    expand to every state entry with that kwargs_type, exposed both flat (``data[k]``)
    and grouped under ``data[kwargs_type]``.

    Raises:
        ValueError: if a ``required`` input or intermediate input resolves to None.
    """
    data = {}

    # Check inputs
    for input_param in self.inputs:
        if input_param.name:
            value = state.get_input(input_param.name)
            if input_param.required and value is None:
                raise ValueError(f"Required input '{input_param.name}' is missing")
            elif value is not None or (value is None and input_param.name not in data):
                data[input_param.name] = value
        elif input_param.kwargs_type:
            # if kwargs_type is provided, get all inputs with matching kwargs_type
            if input_param.kwargs_type not in data:
                data[input_param.kwargs_type] = {}
            inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
            if inputs_kwargs:
                for k, v in inputs_kwargs.items():
                    if v is not None:
                        # NOTE(review): flat copy overwrites unconditionally here, while the
                        # intermediates branch below preserves existing entries — confirm intended.
                        data[k] = v
                        data[input_param.kwargs_type][k] = v

    # Check intermediates
    for input_param in self.intermediate_inputs:
        if input_param.name:
            value = state.get_intermediate(input_param.name)
            if input_param.required and value is None:
                raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
            elif value is not None or (value is None and input_param.name not in data):
                data[input_param.name] = value
        elif input_param.kwargs_type:
            # if kwargs_type is provided, get all intermediates with matching kwargs_type
            if input_param.kwargs_type not in data:
                data[input_param.kwargs_type] = {}
            intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
            if intermediate_kwargs:
                for k, v in intermediate_kwargs.items():
                    if v is not None:
                        # flat copy only if not already set by an input above
                        if k not in data:
                            data[k] = v
                        data[input_param.kwargs_type][k] = v

    return BlockState(**data)
def set_block_state(self, state: PipelineState, block_state: BlockState):
    """Write this block's results from `block_state` back into the shared pipeline state.

    Every declared intermediate output must be present on `block_state`; intermediate
    inputs are written back only when the block replaced the object (identity check),
    so untouched values are not re-written.

    Args:
        state: The shared `PipelineState` to update.
        block_state: The `BlockState` this block produced/mutated.

    Raises:
        ValueError: If a declared intermediate output is missing from `block_state`.
    """
    for output_param in self.intermediate_outputs:
        if not hasattr(block_state, output_param.name):
            raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
        param = getattr(block_state, output_param.name)
        state.set_intermediate(output_param.name, param, output_param.kwargs_type)

    # Single guarded pass over intermediate inputs. (The previous implementation ran an
    # extra unguarded copy of this loop first, writing every value twice and raising
    # TypeError via hasattr(block_state, None) for kwargs-only params with name=None.)
    for input_param in self.intermediate_inputs:
        if input_param.name and hasattr(block_state, input_param.name):
            param = getattr(block_state, input_param.name)
            # Only write back if the object was replaced (identity comparison).
            current_value = state.get_intermediate(input_param.name)
            if current_value is not param:
                state.set_intermediate(input_param.name, param, input_param.kwargs_type)
        elif input_param.kwargs_type:
            # kwargs-type inputs (e.g. "guider_input_fields") expand to several named
            # params; look each up on the state and write back any that were replaced.
            intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
            for param_name, current_value in intermediate_kwargs.items():
                # Skip params the block never materialized (matches the
                # LoopSequentialPipelineBlocks implementation of this method).
                if not hasattr(block_state, param_name):
                    continue
                param = getattr(block_state, param_name)
                if current_value is not param:
                    state.set_intermediate(param_name, param, input_param.kwargs_type)
class AutoPipelineBlocks(ModularPipelineBlocks): class AutoPipelineBlocks(ModularPipelineBlocks):
""" """
...@@ -837,22 +602,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks): ...@@ -837,22 +602,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
return list(required_by_all) return list(required_by_all)
# YiYi TODO: maybe we do not need this, it is only used in docstring,
# intermediate_inputs is by default required, unless you manually handle it inside the block
@property
def required_intermediate_inputs(self) -> List[str]:
    """Intermediate inputs required by *every* sub-block.

    Only meaningful when a default (None) trigger exists; without one the whole
    group may be skipped, so nothing is unconditionally required.
    """
    if None not in self.block_trigger_inputs:
        return []
    blocks = list(self.sub_blocks.values())
    common = set(getattr(blocks[0], "required_intermediate_inputs", set()))
    for candidate in blocks[1:]:
        common &= set(getattr(candidate, "required_intermediate_inputs", set()))
    return list(common)
# YiYi TODO: add test for this # YiYi TODO: add test for this
@property @property
def inputs(self) -> List[Tuple[str, Any]]: def inputs(self) -> List[Tuple[str, Any]]:
...@@ -866,18 +615,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks): ...@@ -866,18 +615,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
input_param.required = False input_param.required = False
return combined_inputs return combined_inputs
@property
def intermediate_inputs(self) -> List[str]:
    """Combined intermediate inputs of all sub-blocks.

    A param's ``required`` flag is set only when the input is required by every
    sub-block (per `required_intermediate_inputs`).
    """
    per_block = [(name, block.intermediate_inputs) for name, block in self.sub_blocks.items()]
    combined = self.combine_inputs(*per_block)
    always_required = set(self.required_intermediate_inputs)
    for param in combined:
        param.required = param.name in always_required
    return combined
@property @property
def intermediate_outputs(self) -> List[str]: def intermediate_outputs(self) -> List[str]:
named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()] named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
...@@ -896,10 +633,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks): ...@@ -896,10 +633,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
block = self.trigger_to_block_map.get(None) block = self.trigger_to_block_map.get(None)
for input_name in self.block_trigger_inputs: for input_name in self.block_trigger_inputs:
if input_name is not None and state.get_input(input_name) is not None: if input_name is not None and state.get(input_name) is not None:
block = self.trigger_to_block_map[input_name]
break
elif input_name is not None and state.get_intermediate(input_name) is not None:
block = self.trigger_to_block_map[input_name] block = self.trigger_to_block_map[input_name]
break break
...@@ -1030,7 +764,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks): ...@@ -1030,7 +764,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
def doc(self): def doc(self):
return make_doc_string( return make_doc_string(
self.inputs, self.inputs,
self.intermediate_inputs,
self.outputs, self.outputs,
self.description, self.description,
class_name=self.__class__.__name__, class_name=self.__class__.__name__,
...@@ -1067,7 +800,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): ...@@ -1067,7 +800,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
@property @property
def model_name(self): def model_name(self):
return next(iter(self.sub_blocks.values())).model_name return next((block.model_name for block in self.sub_blocks.values() if block.model_name is not None), None)
@property @property
def expected_components(self): def expected_components(self):
...@@ -1118,78 +851,52 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): ...@@ -1118,78 +851,52 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
sub_blocks[block_name] = block_cls() sub_blocks[block_name] = block_cls()
self.sub_blocks = sub_blocks self.sub_blocks = sub_blocks
@property def _get_inputs(self):
def required_inputs(self) -> List[str]:
# Get the first block from the dictionary
first_block = next(iter(self.sub_blocks.values()))
required_by_any = set(getattr(first_block, "required_inputs", set()))
# Union with required inputs from all other blocks
for block in list(self.sub_blocks.values())[1:]:
block_required = set(getattr(block, "required_inputs", set()))
required_by_any.update(block_required)
return list(required_by_any)
# YiYi TODO: maybe we do not need this, it is only used in docstring,
# intermediate_inputs is by default required, unless you manually handle it inside the block
@property
def required_intermediate_inputs(self) -> List[str]:
    """Names of the combined intermediate inputs whose ``required`` flag is set."""
    return [param.name for param in self.intermediate_inputs if param.required]
# YiYi TODO: add test for this
@property
def inputs(self) -> List[Tuple[str, Any]]:
    """Combined inputs of all sub-blocks; see `get_inputs` for the merge rules."""
    return self.get_inputs()
def get_inputs(self):
    """Merge the inputs of all sub-blocks; an input is marked required if ANY sub-block requires it."""
    named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
    combined = self.combine_inputs(*named_inputs)
    required_names = set(self.required_inputs)
    for param in combined:
        param.required = param.name in required_names
    return combined
@property
def intermediate_inputs(self) -> List[str]:
    """Intermediate inputs that must be produced by earlier blocks; see `get_intermediate_inputs`."""
    return self.get_intermediate_inputs()
def get_intermediate_inputs(self):
inputs = [] inputs = []
outputs = set() outputs = set()
added_inputs = set()
# Go through all blocks in order # Go through all blocks in order
for block in self.sub_blocks.values(): for block in self.sub_blocks.values():
# Add inputs that aren't in outputs yet # Add inputs that aren't in outputs yet
for inp in block.intermediate_inputs: for inp in block.inputs:
if inp.name not in outputs and inp.name not in added_inputs: if inp.name not in outputs and inp.name not in {input.name for input in inputs}:
inputs.append(inp) inputs.append(inp)
added_inputs.add(inp.name)
# Only add outputs if the block cannot be skipped # Only add outputs if the block cannot be skipped
should_add_outputs = True should_add_outputs = True
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs: if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
should_add_outputs = False should_add_outputs = False
if should_add_outputs: if should_add_outputs:
# Add this block's outputs # Add this block's outputs
block_intermediate_outputs = [out.name for out in block.intermediate_outputs] block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
outputs.update(block_intermediate_outputs) outputs.update(block_intermediate_outputs)
return inputs
return inputs
# YiYi TODO: add test for this
@property
def inputs(self) -> List[Tuple[str, Any]]:
return self._get_inputs()
@property
def required_inputs(self) -> List[str]:
# Get the first block from the dictionary
first_block = next(iter(self.sub_blocks.values()))
required_by_any = set(getattr(first_block, "required_inputs", set()))
# Union with required inputs from all other blocks
for block in list(self.sub_blocks.values())[1:]:
block_required = set(getattr(block, "required_inputs", set()))
required_by_any.update(block_required)
return list(required_by_any)
@property @property
def intermediate_outputs(self) -> List[str]: def intermediate_outputs(self) -> List[str]:
named_outputs = [] named_outputs = []
for name, block in self.sub_blocks.items(): for name, block in self.sub_blocks.items():
inp_names = {inp.name for inp in block.intermediate_inputs} inp_names = {inp.name for inp in block.inputs}
# so we only need to list new variables as intermediate_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce) # so we only need to list new variables as intermediate_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce)
# filter out them here so they do not end up as intermediate_outputs # filter out them here so they do not end up as intermediate_outputs
if name not in inp_names: if name not in inp_names:
...@@ -1407,7 +1114,6 @@ class SequentialPipelineBlocks(ModularPipelineBlocks): ...@@ -1407,7 +1114,6 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
def doc(self): def doc(self):
return make_doc_string( return make_doc_string(
self.inputs, self.inputs,
self.intermediate_inputs,
self.outputs, self.outputs,
self.description, self.description,
class_name=self.__class__.__name__, class_name=self.__class__.__name__,
...@@ -1457,16 +1163,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): ...@@ -1457,16 +1163,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
"""List of input parameters. Must be implemented by subclasses.""" """List of input parameters. Must be implemented by subclasses."""
return [] return []
@property
def loop_intermediate_inputs(self) -> List[InputParam]:
    """List of intermediate input parameters declared by the loop itself. Must be implemented by subclasses."""
    return []
@property
def loop_intermediate_outputs(self) -> List[OutputParam]:
    """List of intermediate output parameters declared by the loop itself. Must be implemented by subclasses."""
    return []
@property @property
def loop_required_inputs(self) -> List[str]: def loop_required_inputs(self) -> List[str]:
input_names = [] input_names = []
...@@ -1476,12 +1172,9 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): ...@@ -1476,12 +1172,9 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
return input_names return input_names
@property @property
def loop_required_intermediate_inputs(self) -> List[str]: def loop_intermediate_outputs(self) -> List[OutputParam]:
input_names = [] """List of intermediate output parameters. Must be implemented by subclasses."""
for input_param in self.loop_intermediate_inputs: return []
if input_param.required:
input_names.append(input_param.name)
return input_names
# modified from SequentialPipelineBlocks to include loop_expected_components # modified from SequentialPipelineBlocks to include loop_expected_components
@property @property
...@@ -1509,43 +1202,16 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): ...@@ -1509,43 +1202,16 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
expected_configs.append(config) expected_configs.append(config)
return expected_configs return expected_configs
# modified from SequentialPipelineBlocks to include loop_inputs def _get_inputs(self):
def get_inputs(self):
    """Merge sub-block inputs with the loop's own inputs; required if any block requires it."""
    named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
    named_inputs.append(("loop", self.loop_inputs))
    combined = self.combine_inputs(*named_inputs)
    required_names = set(self.required_inputs)
    for param in combined:
        param.required = param.name in required_names
    return combined
@property
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks.inputs
def inputs(self):
    """Combined inputs of the loop and its sub-blocks; see `get_inputs`."""
    return self.get_inputs()
# modified from SequentialPipelineBlocks to include loop_intermediate_inputs
@property
def intermediate_inputs(self):
    """Sub-block intermediate inputs, extended with loop-level ones not already present."""
    params = self.get_intermediate_inputs()
    existing = {param.name for param in params}
    for loop_param in self.loop_intermediate_inputs:
        if loop_param.name not in existing:
            params.append(loop_param)
    return params
# modified from SequentialPipelineBlocks
def get_intermediate_inputs(self):
inputs = [] inputs = []
inputs.extend(self.loop_inputs)
outputs = set() outputs = set()
# Go through all blocks in order for name, block in self.sub_blocks.items():
for block in self.sub_blocks.values():
# Add inputs that aren't in outputs yet # Add inputs that aren't in outputs yet
inputs.extend(input_name for input_name in block.intermediate_inputs if input_name.name not in outputs) for inp in block.inputs:
if inp.name not in outputs and inp not in inputs:
inputs.append(inp)
# Only add outputs if the block cannot be skipped # Only add outputs if the block cannot be skipped
should_add_outputs = True should_add_outputs = True
...@@ -1556,8 +1222,20 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): ...@@ -1556,8 +1222,20 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
# Add this block's outputs # Add this block's outputs
block_intermediate_outputs = [out.name for out in block.intermediate_outputs] block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
outputs.update(block_intermediate_outputs) outputs.update(block_intermediate_outputs)
for input_param in inputs:
if input_param.name in self.required_inputs:
input_param.required = True
else:
input_param.required = False
return inputs return inputs
@property
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks.inputs
def inputs(self):
return self._get_inputs()
# modified from SequentialPipelineBlocks, if any additional input required by the loop is required by the block # modified from SequentialPipelineBlocks, if any additional input required by the loop is required by the block
@property @property
def required_inputs(self) -> List[str]: def required_inputs(self) -> List[str]:
...@@ -1575,19 +1253,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): ...@@ -1575,19 +1253,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
return list(required_by_any) return list(required_by_any)
# YiYi TODO: maybe we do not need this, it is only used in docstring,
# intermediate_inputs is by default required, unless you manually handle it inside the block
@property
def required_intermediate_inputs(self) -> List[str]:
    """Required intermediate-input names from the combined sub-blocks plus the loop itself."""
    names = [param.name for param in self.intermediate_inputs if param.required]
    names.extend(param.name for param in self.loop_intermediate_inputs if param.required)
    return names
# YiYi TODO: this need to be thought about more # YiYi TODO: this need to be thought about more
# modified from SequentialPipelineBlocks to include loop_intermediate_outputs # modified from SequentialPipelineBlocks to include loop_intermediate_outputs
@property @property
...@@ -1653,80 +1318,10 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): ...@@ -1653,80 +1318,10 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
def __call__(self, components, state: PipelineState) -> PipelineState: def __call__(self, components, state: PipelineState) -> PipelineState:
raise NotImplementedError("`__call__` method needs to be implemented by the subclass") raise NotImplementedError("`__call__` method needs to be implemented by the subclass")
def get_block_state(self, state: PipelineState) -> dict:
    """Resolve the loop's declared inputs and intermediate inputs from the pipeline state.

    Named params are looked up individually; params declared only via ``kwargs_type``
    expand to every matching state entry, exposed both flat (``data[k]``) and grouped
    under ``data[kwargs_type]``.

    Raises:
        ValueError: if a ``required`` input or intermediate input resolves to None.
    """
    data = {}

    # Check inputs
    for input_param in self.inputs:
        if input_param.name:
            value = state.get_input(input_param.name)
            if input_param.required and value is None:
                raise ValueError(f"Required input '{input_param.name}' is missing")
            elif value is not None or (value is None and input_param.name not in data):
                data[input_param.name] = value
        elif input_param.kwargs_type:
            # if kwargs_type is provided, get all inputs with matching kwargs_type
            if input_param.kwargs_type not in data:
                data[input_param.kwargs_type] = {}
            inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
            if inputs_kwargs:
                for k, v in inputs_kwargs.items():
                    if v is not None:
                        # NOTE(review): flat copy overwrites unconditionally here, while the
                        # intermediates branch below preserves existing entries — confirm intended.
                        data[k] = v
                        data[input_param.kwargs_type][k] = v

    # Check intermediates
    for input_param in self.intermediate_inputs:
        if input_param.name:
            value = state.get_intermediate(input_param.name)
            if input_param.required and value is None:
                raise ValueError(f"Required intermediate input '{input_param.name}' is missing.")
            elif value is not None or (value is None and input_param.name not in data):
                data[input_param.name] = value
        elif input_param.kwargs_type:
            # if kwargs_type is provided, get all intermediates with matching kwargs_type
            if input_param.kwargs_type not in data:
                data[input_param.kwargs_type] = {}
            intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
            if intermediate_kwargs:
                for k, v in intermediate_kwargs.items():
                    if v is not None:
                        # flat copy only if not already set by an input above
                        if k not in data:
                            data[k] = v
                        data[input_param.kwargs_type][k] = v

    return BlockState(**data)
def set_block_state(self, state: PipelineState, block_state: BlockState):
    """Write the loop's results from `block_state` back into the shared pipeline state.

    Every declared intermediate output must be present on `block_state`; intermediate
    inputs are written back only when the object was replaced (identity check).

    Raises:
        ValueError: If a declared intermediate output is missing from `block_state`.
    """
    for output_param in self.intermediate_outputs:
        if not hasattr(block_state, output_param.name):
            raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
        param = getattr(block_state, output_param.name)
        state.set_intermediate(output_param.name, param, output_param.kwargs_type)

    for input_param in self.intermediate_inputs:
        if input_param.name and hasattr(block_state, input_param.name):
            param = getattr(block_state, input_param.name)
            # Only add if the value is different from what's in the state
            current_value = state.get_intermediate(input_param.name)
            if current_value is not param:  # Using identity comparison to check if object was modified
                state.set_intermediate(input_param.name, param, input_param.kwargs_type)
        elif input_param.kwargs_type:
            # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
            # we need to first find out which inputs are and loop through them.
            intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
            for param_name, current_value in intermediate_kwargs.items():
                # params the block never materialized are skipped
                if not hasattr(block_state, param_name):
                    continue
                param = getattr(block_state, param_name)
                if current_value is not param:  # Using identity comparison to check if object was modified
                    state.set_intermediate(param_name, param, input_param.kwargs_type)
@property @property
def doc(self): def doc(self):
return make_doc_string( return make_doc_string(
self.inputs, self.inputs,
self.intermediate_inputs,
self.outputs, self.outputs,
self.description, self.description,
class_name=self.__class__.__name__, class_name=self.__class__.__name__,
...@@ -1946,97 +1541,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin): ...@@ -1946,97 +1541,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
params[input_param.name] = input_param.default params[input_param.name] = input_param.default
return params return params
def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
    """
    Execute the pipeline by running the pipeline blocks with the given inputs.

    Args:
        state (`PipelineState`, optional):
            PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be
            created based on the user inputs and the pipeline blocks's requirement.
        output (`str` or `List[str]`, optional):
            Optional specification of what to return:
            - None: Returns the complete `PipelineState` with all inputs and intermediates (default)
            - str: Returns a specific intermediate value from the state (e.g. `output="image"`)
            - List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image",
              "latents"]`)

    Examples:
        ```python
        # Get complete pipeline state
        state = pipeline(prompt="A beautiful sunset", num_inference_steps=20)
        print(state.intermediates)  # All intermediate outputs

        # Get specific output
        image = pipeline(prompt="A beautiful sunset", output="image")

        # Get multiple specific outputs
        results = pipeline(prompt="A beautiful sunset", output=["image", "latents"])
        image, latents = results["image"], results["latents"]

        # Continue from previous state
        state = pipeline(prompt="A beautiful sunset")
        new_state = pipeline(state=state, output="image")  # Continue processing
        ```

    Returns:
        - If `output` is None: Complete `PipelineState` containing all inputs and intermediates
        - If `output` is str: The specific intermediate value from the state (e.g. `output="image"`)
        - If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g.
          `output=["image", "latents"]`)
    """
    if state is None:
        state = PipelineState()

    # Make a copy of the input kwargs
    passed_kwargs = kwargs.copy()

    # Add inputs to state, using defaults if not provided in the kwargs or the state
    # if same input already in the state, will override it if provided in the kwargs
    intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs]
    for expected_input_param in self.blocks.inputs:
        name = expected_input_param.name
        default = expected_input_param.default
        kwargs_type = expected_input_param.kwargs_type
        if name in passed_kwargs:
            if name not in intermediate_inputs:
                # consumed only here: safe to pop
                state.set_input(name, passed_kwargs.pop(name), kwargs_type)
            else:
                # also an intermediate input: keep it in passed_kwargs so the
                # intermediates loop below can consume (and pop) it as well
                state.set_input(name, passed_kwargs[name], kwargs_type)
        elif name not in state.inputs:
            state.set_input(name, default, kwargs_type)

    for expected_intermediate_param in self.blocks.intermediate_inputs:
        name = expected_intermediate_param.name
        kwargs_type = expected_intermediate_param.kwargs_type
        if name in passed_kwargs:
            state.set_intermediate(name, passed_kwargs.pop(name), kwargs_type)

    # Warn about unexpected inputs
    if len(passed_kwargs) > 0:
        warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
    # Run the pipeline
    with torch.no_grad():
        try:
            _, state = self.blocks(self, state)
        except Exception:
            # log which top-level blocks container failed, then re-raise unchanged
            error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
            logger.error(error_msg)
            raise

    if output is None:
        return state

    elif isinstance(output, str):
        return state.get_intermediate(output)

    elif isinstance(output, (list, tuple)):
        return state.get_intermediates(output)
    else:
        raise ValueError(f"Output '{output}' is not a valid output type")
def load_default_components(self, **kwargs): def load_default_components(self, **kwargs):
""" """
Load from_pretrained components using the loading specs in the config dict. Load from_pretrained components using the loading specs in the config dict.
...@@ -2860,3 +2364,83 @@ class ModularPipeline(ConfigMixin, PushToHubMixin): ...@@ -2860,3 +2364,83 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
for sub_block_name, sub_block in self.blocks.sub_blocks.items(): for sub_block_name, sub_block in self.blocks.sub_blocks.items():
if hasattr(sub_block, "set_progress_bar_config"): if hasattr(sub_block, "set_progress_bar_config"):
sub_block.set_progress_bar_config(**kwargs) sub_block.set_progress_bar_config(**kwargs)
def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs):
    """
    Execute the pipeline by running the pipeline blocks with the given inputs.

    Args:
        state (`PipelineState`, optional):
            PipelineState instance contains inputs and intermediate values. If None, a new `PipelineState` will be
            created based on the user inputs and the pipeline blocks's requirement.
        output (`str` or `List[str]`, optional):
            Optional specification of what to return:
            - None: Returns the complete `PipelineState` with all inputs and intermediates (default)
            - str: Returns a specific intermediate value from the state (e.g. `output="image"`)
            - List[str]: Returns a dictionary of specific intermediate values (e.g. `output=["image",
              "latents"]`)

    Examples:
        ```python
        # Get complete pipeline state
        state = pipeline(prompt="A beautiful sunset", num_inference_steps=20)
        print(state.intermediates)  # All intermediate outputs

        # Get specific output
        image = pipeline(prompt="A beautiful sunset", output="image")

        # Get multiple specific outputs
        results = pipeline(prompt="A beautiful sunset", output=["image", "latents"])
        image, latents = results["image"], results["latents"]

        # Continue from previous state
        state = pipeline(prompt="A beautiful sunset")
        new_state = pipeline(state=state, output="image")  # Continue processing
        ```

    Returns:
        - If `output` is None: Complete `PipelineState` containing all inputs and intermediates
        - If `output` is str: The specific intermediate value from the state (e.g. `output="image"`)
        - If `output` is List[str]: Dictionary mapping output names to their values from the state (e.g.
          `output=["image", "latents"]`)
    """
    if state is None:
        state = PipelineState()

    # Make a copy of the input kwargs
    passed_kwargs = kwargs.copy()

    # Add inputs to state, using defaults if not provided in the kwargs or the state
    # if same input already in the state, will override it if provided in the kwargs
    for expected_input_param in self.blocks.inputs:
        name = expected_input_param.name
        default = expected_input_param.default
        kwargs_type = expected_input_param.kwargs_type
        if name in passed_kwargs:
            state.set(name, passed_kwargs.pop(name), kwargs_type)
        elif name not in state.values:
            state.set(name, default, kwargs_type)

    # Warn about unexpected inputs
    if len(passed_kwargs) > 0:
        warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.")
    # Run the pipeline
    with torch.no_grad():
        try:
            _, state = self.blocks(self, state)
        except Exception:
            # log which top-level blocks container failed, then re-raise unchanged
            error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n"
            logger.error(error_msg)
            raise

    if output is None:
        return state

    if isinstance(output, str):
        # NOTE(review): the str and list branches both delegate to state.get —
        # presumably state.get accepts sequences too; confirm against PipelineState.
        return state.get(output)

    elif isinstance(output, (list, tuple)):
        return state.get(output)
    else:
        raise ValueError(f"Output '{output}' is not a valid output type")
...@@ -618,7 +618,6 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines ...@@ -618,7 +618,6 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
def make_doc_string( def make_doc_string(
inputs, inputs,
intermediate_inputs,
outputs, outputs,
description="", description="",
class_name=None, class_name=None,
...@@ -664,7 +663,7 @@ def make_doc_string( ...@@ -664,7 +663,7 @@ def make_doc_string(
output += configs_str + "\n\n" output += configs_str + "\n\n"
# Add inputs section # Add inputs section
output += format_input_params(inputs + intermediate_inputs, indent_level=2) output += format_input_params(inputs, indent_level=2)
# Add outputs section # Add outputs section
output += "\n\n" output += "\n\n"
......
...@@ -27,7 +27,7 @@ from ...schedulers import EulerDiscreteScheduler ...@@ -27,7 +27,7 @@ from ...schedulers import EulerDiscreteScheduler
from ...utils import logging from ...utils import logging
from ...utils.torch_utils import randn_tensor, unwrap_module from ...utils.torch_utils import randn_tensor, unwrap_module
from ..modular_pipeline import ( from ..modular_pipeline import (
PipelineBlock, ModularPipelineBlocks,
PipelineState, PipelineState,
) )
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
...@@ -195,7 +195,7 @@ def prepare_latents_img2img( ...@@ -195,7 +195,7 @@ def prepare_latents_img2img(
return latents return latents
class StableDiffusionXLInputStep(PipelineBlock): class StableDiffusionXLInputStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -213,11 +213,6 @@ class StableDiffusionXLInputStep(PipelineBlock): ...@@ -213,11 +213,6 @@ class StableDiffusionXLInputStep(PipelineBlock):
def inputs(self) -> List[InputParam]: def inputs(self) -> List[InputParam]:
return [ return [
InputParam("num_images_per_prompt", default=1), InputParam("num_images_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam( InputParam(
"prompt_embeds", "prompt_embeds",
required=True, required=True,
...@@ -394,7 +389,7 @@ class StableDiffusionXLInputStep(PipelineBlock): ...@@ -394,7 +389,7 @@ class StableDiffusionXLInputStep(PipelineBlock):
return components, state return components, state
class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): class StableDiffusionXLImg2ImgSetTimestepsStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -421,11 +416,6 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): ...@@ -421,11 +416,6 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
InputParam("denoising_start"), InputParam("denoising_start"),
# YiYi TODO: do we need num_images_per_prompt here? # YiYi TODO: do we need num_images_per_prompt here?
InputParam("num_images_per_prompt", default=1), InputParam("num_images_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam( InputParam(
"batch_size", "batch_size",
required=True, required=True,
...@@ -543,7 +533,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): ...@@ -543,7 +533,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
return components, state return components, state
class StableDiffusionXLSetTimestepsStep(PipelineBlock): class StableDiffusionXLSetTimestepsStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -611,7 +601,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock): ...@@ -611,7 +601,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock):
return components, state return components, state
class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): class StableDiffusionXLInpaintPrepareLatentsStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -640,11 +630,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): ...@@ -640,11 +630,6 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
"`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of " "`num_inference_steps`. A value of 1, therefore, essentially ignores `image`. Note that in the case of "
"`denoising_start` being declared as an integer, the value of `strength` will be ignored.", "`denoising_start` being declared as an integer, the value of `strength` will be ignored.",
), ),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"), InputParam("generator"),
InputParam( InputParam(
"batch_size", "batch_size",
...@@ -890,7 +875,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): ...@@ -890,7 +875,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock):
return components, state return components, state
class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): class StableDiffusionXLImg2ImgPrepareLatentsStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -910,11 +895,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): ...@@ -910,11 +895,6 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
InputParam("latents"), InputParam("latents"),
InputParam("num_images_per_prompt", default=1), InputParam("num_images_per_prompt", default=1),
InputParam("denoising_start"), InputParam("denoising_start"),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"), InputParam("generator"),
InputParam( InputParam(
"latent_timestep", "latent_timestep",
...@@ -971,7 +951,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): ...@@ -971,7 +951,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock):
return components, state return components, state
class StableDiffusionXLPrepareLatentsStep(PipelineBlock): class StableDiffusionXLPrepareLatentsStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -992,11 +972,6 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): ...@@ -992,11 +972,6 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
InputParam("width"), InputParam("width"),
InputParam("latents"), InputParam("latents"),
InputParam("num_images_per_prompt", default=1), InputParam("num_images_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"), InputParam("generator"),
InputParam( InputParam(
"batch_size", "batch_size",
...@@ -1082,7 +1057,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): ...@@ -1082,7 +1057,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock):
return components, state return components, state
class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -1119,11 +1094,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): ...@@ -1119,11 +1094,6 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
InputParam("num_images_per_prompt", default=1), InputParam("num_images_per_prompt", default=1),
InputParam("aesthetic_score", default=6.0), InputParam("aesthetic_score", default=6.0),
InputParam("negative_aesthetic_score", default=2.0), InputParam("negative_aesthetic_score", default=2.0),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam( InputParam(
"latents", "latents",
required=True, required=True,
...@@ -1306,7 +1276,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): ...@@ -1306,7 +1276,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock):
return components, state return components, state
class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): class StableDiffusionXLPrepareAdditionalConditioningStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -1335,11 +1305,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): ...@@ -1335,11 +1305,6 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
InputParam("crops_coords_top_left", default=(0, 0)), InputParam("crops_coords_top_left", default=(0, 0)),
InputParam("negative_crops_coords_top_left", default=(0, 0)), InputParam("negative_crops_coords_top_left", default=(0, 0)),
InputParam("num_images_per_prompt", default=1), InputParam("num_images_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam( InputParam(
"latents", "latents",
required=True, required=True,
...@@ -1489,7 +1454,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): ...@@ -1489,7 +1454,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock):
return components, state return components, state
class StableDiffusionXLControlNetInputStep(PipelineBlock): class StableDiffusionXLControlNetInputStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -1517,11 +1482,6 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock): ...@@ -1517,11 +1482,6 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
InputParam("controlnet_conditioning_scale", default=1.0), InputParam("controlnet_conditioning_scale", default=1.0),
InputParam("guess_mode", default=False), InputParam("guess_mode", default=False),
InputParam("num_images_per_prompt", default=1), InputParam("num_images_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam( InputParam(
"latents", "latents",
required=True, required=True,
...@@ -1708,7 +1668,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock): ...@@ -1708,7 +1668,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock):
return components, state return components, state
class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): class StableDiffusionXLControlNetUnionInputStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -1737,11 +1697,6 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): ...@@ -1737,11 +1697,6 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock):
InputParam("controlnet_conditioning_scale", default=1.0), InputParam("controlnet_conditioning_scale", default=1.0),
InputParam("guess_mode", default=False), InputParam("guess_mode", default=False),
InputParam("num_images_per_prompt", default=1), InputParam("num_images_per_prompt", default=1),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam( InputParam(
"latents", "latents",
required=True, required=True,
......
...@@ -24,7 +24,7 @@ from ...models import AutoencoderKL ...@@ -24,7 +24,7 @@ from ...models import AutoencoderKL
from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor from ...models.attention_processor import AttnProcessor2_0, XFormersAttnProcessor
from ...utils import logging from ...utils import logging
from ..modular_pipeline import ( from ..modular_pipeline import (
PipelineBlock, ModularPipelineBlocks,
PipelineState, PipelineState,
) )
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
...@@ -33,7 +33,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam ...@@ -33,7 +33,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class StableDiffusionXLDecodeStep(PipelineBlock): class StableDiffusionXLDecodeStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -56,17 +56,12 @@ class StableDiffusionXLDecodeStep(PipelineBlock): ...@@ -56,17 +56,12 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]: def inputs(self) -> List[Tuple[str, Any]]:
return [ return [
InputParam("output_type", default="pil"), InputParam("output_type", default="pil"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam( InputParam(
"latents", "latents",
required=True, required=True,
type_hint=torch.Tensor, type_hint=torch.Tensor,
description="The denoised latents from the denoising step", description="The denoised latents from the denoising step",
) ),
] ]
@property @property
...@@ -157,7 +152,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock): ...@@ -157,7 +152,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock):
return components, state return components, state
class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): class StableDiffusionXLInpaintOverlayMaskStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -184,11 +179,6 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): ...@@ -184,11 +179,6 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock):
InputParam("image"), InputParam("image"),
InputParam("mask_image"), InputParam("mask_image"),
InputParam("padding_mask_crop"), InputParam("padding_mask_crop"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam( InputParam(
"images", "images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]], type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
......
...@@ -25,7 +25,7 @@ from ...utils import logging ...@@ -25,7 +25,7 @@ from ...utils import logging
from ..modular_pipeline import ( from ..modular_pipeline import (
BlockState, BlockState,
LoopSequentialPipelineBlocks, LoopSequentialPipelineBlocks,
PipelineBlock, ModularPipelineBlocks,
PipelineState, PipelineState,
) )
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
...@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name ...@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi experimenting composible denoise loop # YiYi experimenting composible denoise loop
# loop step (1): prepare latent input for denoiser # loop step (1): prepare latent input for denoiser
class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock): class StableDiffusionXLLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -55,7 +55,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock): ...@@ -55,7 +55,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
) )
@property @property
def intermediate_inputs(self) -> List[str]: def inputs(self) -> List[str]:
return [ return [
InputParam( InputParam(
"latents", "latents",
...@@ -73,7 +73,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock): ...@@ -73,7 +73,7 @@ class StableDiffusionXLLoopBeforeDenoiser(PipelineBlock):
# loop step (1): prepare latent input for denoiser (with inpainting) # loop step (1): prepare latent input for denoiser (with inpainting)
class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock): class StableDiffusionXLInpaintLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -91,7 +91,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock): ...@@ -91,7 +91,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
) )
@property @property
def intermediate_inputs(self) -> List[str]: def inputs(self) -> List[str]:
return [ return [
InputParam( InputParam(
"latents", "latents",
...@@ -144,7 +144,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock): ...@@ -144,7 +144,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock):
# loop step (2): denoise the latents with guidance # loop step (2): denoise the latents with guidance
class StableDiffusionXLLoopDenoiser(PipelineBlock): class StableDiffusionXLLoopDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -171,11 +171,6 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock): ...@@ -171,11 +171,6 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]: def inputs(self) -> List[Tuple[str, Any]]:
return [ return [
InputParam("cross_attention_kwargs"), InputParam("cross_attention_kwargs"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam( InputParam(
"num_inference_steps", "num_inference_steps",
required=True, required=True,
...@@ -249,7 +244,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock): ...@@ -249,7 +244,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock):
# loop step (2): denoise the latents with guidance (with controlnet) # loop step (2): denoise the latents with guidance (with controlnet)
class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): class StableDiffusionXLControlNetLoopDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -277,11 +272,6 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): ...@@ -277,11 +272,6 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]: def inputs(self) -> List[Tuple[str, Any]]:
return [ return [
InputParam("cross_attention_kwargs"), InputParam("cross_attention_kwargs"),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam( InputParam(
"controlnet_cond", "controlnet_cond",
required=True, required=True,
...@@ -449,7 +439,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): ...@@ -449,7 +439,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock):
# loop step (3): scheduler step to update latents # loop step (3): scheduler step to update latents
class StableDiffusionXLLoopAfterDenoiser(PipelineBlock): class StableDiffusionXLLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -470,11 +460,6 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock): ...@@ -470,11 +460,6 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]: def inputs(self) -> List[Tuple[str, Any]]:
return [ return [
InputParam("eta", default=0.0), InputParam("eta", default=0.0),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"), InputParam("generator"),
] ]
...@@ -520,7 +505,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock): ...@@ -520,7 +505,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock):
# loop step (3): scheduler step to update latents (with inpainting) # loop step (3): scheduler step to update latents (with inpainting)
class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock): class StableDiffusionXLInpaintLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -542,11 +527,6 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock): ...@@ -542,11 +527,6 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock):
def inputs(self) -> List[Tuple[str, Any]]: def inputs(self) -> List[Tuple[str, Any]]:
return [ return [
InputParam("eta", default=0.0), InputParam("eta", default=0.0),
]
@property
def intermediate_inputs(self) -> List[str]:
return [
InputParam("generator"), InputParam("generator"),
InputParam( InputParam(
"timesteps", "timesteps",
...@@ -660,7 +640,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): ...@@ -660,7 +640,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
] ]
@property @property
def loop_intermediate_inputs(self) -> List[InputParam]: def loop_inputs(self) -> List[InputParam]:
return [ return [
InputParam( InputParam(
"timesteps", "timesteps",
......
...@@ -35,7 +35,7 @@ from ...utils import ( ...@@ -35,7 +35,7 @@ from ...utils import (
scale_lora_layers, scale_lora_layers,
unscale_lora_layers, unscale_lora_layers,
) )
from ..modular_pipeline import PipelineBlock, PipelineState from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import StableDiffusionXLModularPipeline from .modular_pipeline import StableDiffusionXLModularPipeline
...@@ -57,7 +57,7 @@ def retrieve_latents( ...@@ -57,7 +57,7 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output") raise AttributeError("Could not access latents of provided encoder_output")
class StableDiffusionXLIPAdapterStep(PipelineBlock): class StableDiffusionXLIPAdapterStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -215,7 +215,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock): ...@@ -215,7 +215,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock):
return components, state return components, state
class StableDiffusionXLTextEncoderStep(PipelineBlock): class StableDiffusionXLTextEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -576,7 +576,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): ...@@ -576,7 +576,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock):
return components, state return components, state
class StableDiffusionXLVaeEncoderStep(PipelineBlock): class StableDiffusionXLVaeEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -601,11 +601,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock): ...@@ -601,11 +601,6 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
InputParam("image", required=True), InputParam("image", required=True),
InputParam("height"), InputParam("height"),
InputParam("width"), InputParam("width"),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("generator"), InputParam("generator"),
InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
InputParam( InputParam(
...@@ -668,12 +663,11 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock): ...@@ -668,12 +663,11 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
block_state.device = components._execution_device block_state.device = components._execution_device
block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
block_state.image = components.image_processor.preprocess( image = components.image_processor.preprocess(
block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
) )
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) image = image.to(device=block_state.device, dtype=block_state.dtype)
block_state.batch_size = image.shape[0]
block_state.batch_size = block_state.image.shape[0]
# if generator is a list, make sure the length of it matches the length of images (both should be batch_size) # if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
...@@ -682,16 +676,14 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock): ...@@ -682,16 +676,14 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock):
f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
) )
block_state.image_latents = self._encode_vae_image( block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
components, image=block_state.image, generator=block_state.generator
)
self.set_block_state(state, block_state) self.set_block_state(state, block_state)
return components, state return components, state
class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): class StableDiffusionXLInpaintVaeEncoderStep(ModularPipelineBlocks):
model_name = "stable-diffusion-xl" model_name = "stable-diffusion-xl"
@property @property
...@@ -726,11 +718,6 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): ...@@ -726,11 +718,6 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
InputParam("image", required=True), InputParam("image", required=True),
InputParam("mask_image", required=True), InputParam("mask_image", required=True),
InputParam("padding_mask_crop"), InputParam("padding_mask_crop"),
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return [
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"), InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
InputParam("generator"), InputParam("generator"),
] ]
...@@ -860,34 +847,32 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): ...@@ -860,34 +847,32 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock):
block_state.crops_coords = None block_state.crops_coords = None
block_state.resize_mode = "default" block_state.resize_mode = "default"
block_state.image = components.image_processor.preprocess( image = components.image_processor.preprocess(
block_state.image, block_state.image,
height=block_state.height, height=block_state.height,
width=block_state.width, width=block_state.width,
crops_coords=block_state.crops_coords, crops_coords=block_state.crops_coords,
resize_mode=block_state.resize_mode, resize_mode=block_state.resize_mode,
) )
block_state.image = block_state.image.to(dtype=torch.float32) image = image.to(dtype=torch.float32)
block_state.mask = components.mask_processor.preprocess( mask = components.mask_processor.preprocess(
block_state.mask_image, block_state.mask_image,
height=block_state.height, height=block_state.height,
width=block_state.width, width=block_state.width,
resize_mode=block_state.resize_mode, resize_mode=block_state.resize_mode,
crops_coords=block_state.crops_coords, crops_coords=block_state.crops_coords,
) )
block_state.masked_image = block_state.image * (block_state.mask < 0.5) block_state.masked_image = image * (mask < 0.5)
block_state.batch_size = block_state.image.shape[0] block_state.batch_size = image.shape[0]
block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) image = image.to(device=block_state.device, dtype=block_state.dtype)
block_state.image_latents = self._encode_vae_image( block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)
components, image=block_state.image, generator=block_state.generator
)
# 7. Prepare mask latent variables # 7. Prepare mask latent variables
block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents( block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
components, components,
block_state.mask, mask,
block_state.masked_image, block_state.masked_image,
block_state.batch_size, block_state.batch_size,
block_state.height, block_state.height,
......
...@@ -247,10 +247,6 @@ SDXL_INPUTS_SCHEMA = { ...@@ -247,10 +247,6 @@ SDXL_INPUTS_SCHEMA = {
"control_mode": InputParam( "control_mode": InputParam(
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet" "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
), ),
}
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam( "prompt_embeds": InputParam(
"prompt_embeds", "prompt_embeds",
type_hint=torch.Tensor, type_hint=torch.Tensor,
...@@ -271,13 +267,6 @@ SDXL_INTERMEDIATE_INPUTS_SCHEMA = { ...@@ -271,13 +267,6 @@ SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"preprocess_kwargs": InputParam( "preprocess_kwargs": InputParam(
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor" "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
), ),
"latents": InputParam(
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
),
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
),
"latent_timestep": InputParam( "latent_timestep": InputParam(
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep" "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
), ),
......
...@@ -20,7 +20,7 @@ import torch ...@@ -20,7 +20,7 @@ import torch
from ...schedulers import UniPCMultistepScheduler from ...schedulers import UniPCMultistepScheduler
from ...utils import logging from ...utils import logging
from ...utils.torch_utils import randn_tensor from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import PipelineBlock, PipelineState from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline from .modular_pipeline import WanModularPipeline
...@@ -94,7 +94,7 @@ def retrieve_timesteps( ...@@ -94,7 +94,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps return timesteps, num_inference_steps
class WanInputStep(PipelineBlock): class WanInputStep(ModularPipelineBlocks):
model_name = "wan" model_name = "wan"
@property @property
...@@ -194,7 +194,7 @@ class WanInputStep(PipelineBlock): ...@@ -194,7 +194,7 @@ class WanInputStep(PipelineBlock):
return components, state return components, state
class WanSetTimestepsStep(PipelineBlock): class WanSetTimestepsStep(ModularPipelineBlocks):
model_name = "wan" model_name = "wan"
@property @property
...@@ -243,7 +243,7 @@ class WanSetTimestepsStep(PipelineBlock): ...@@ -243,7 +243,7 @@ class WanSetTimestepsStep(PipelineBlock):
return components, state return components, state
class WanPrepareLatentsStep(PipelineBlock): class WanPrepareLatentsStep(ModularPipelineBlocks):
model_name = "wan" model_name = "wan"
@property @property
......
...@@ -22,14 +22,14 @@ from ...configuration_utils import FrozenDict ...@@ -22,14 +22,14 @@ from ...configuration_utils import FrozenDict
from ...models import AutoencoderKLWan from ...models import AutoencoderKLWan
from ...utils import logging from ...utils import logging
from ...video_processor import VideoProcessor from ...video_processor import VideoProcessor
from ..modular_pipeline import PipelineBlock, PipelineState from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanDecodeStep(PipelineBlock): class WanDecodeStep(ModularPipelineBlocks):
model_name = "wan" model_name = "wan"
@property @property
......
...@@ -24,7 +24,7 @@ from ...utils import logging ...@@ -24,7 +24,7 @@ from ...utils import logging
from ..modular_pipeline import ( from ..modular_pipeline import (
BlockState, BlockState,
LoopSequentialPipelineBlocks, LoopSequentialPipelineBlocks,
PipelineBlock, ModularPipelineBlocks,
PipelineState, PipelineState,
) )
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
...@@ -34,7 +34,7 @@ from .modular_pipeline import WanModularPipeline ...@@ -34,7 +34,7 @@ from .modular_pipeline import WanModularPipeline
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class WanLoopDenoiser(PipelineBlock): class WanLoopDenoiser(ModularPipelineBlocks):
model_name = "wan" model_name = "wan"
@property @property
...@@ -132,7 +132,7 @@ class WanLoopDenoiser(PipelineBlock): ...@@ -132,7 +132,7 @@ class WanLoopDenoiser(PipelineBlock):
return components, block_state return components, block_state
class WanLoopAfterDenoiser(PipelineBlock): class WanLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "wan" model_name = "wan"
@property @property
......
...@@ -22,7 +22,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel ...@@ -22,7 +22,7 @@ from transformers import AutoTokenizer, UMT5EncoderModel
from ...configuration_utils import FrozenDict from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance from ...guiders import ClassifierFreeGuidance
from ...utils import is_ftfy_available, logging from ...utils import is_ftfy_available, logging
from ..modular_pipeline import PipelineBlock, PipelineState from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline from .modular_pipeline import WanModularPipeline
...@@ -51,7 +51,7 @@ def prompt_clean(text): ...@@ -51,7 +51,7 @@ def prompt_clean(text):
return text return text
class WanTextEncoderStep(PipelineBlock): class WanTextEncoderStep(ModularPipelineBlocks):
model_name = "wan" model_name = "wan"
@property @property
......
...@@ -117,13 +117,9 @@ class SDXLModularIPAdapterTests: ...@@ -117,13 +117,9 @@ class SDXLModularIPAdapterTests:
_ = blocks.sub_blocks.pop("ip_adapter") _ = blocks.sub_blocks.pop("ip_adapter")
parameters = blocks.input_names parameters = blocks.input_names
intermediate_parameters = blocks.intermediate_input_names
assert "ip_adapter_image" not in parameters, ( assert "ip_adapter_image" not in parameters, (
"`ip_adapter_image` argument must be removed from the `__call__` method" "`ip_adapter_image` argument must be removed from the `__call__` method"
) )
assert "ip_adapter_image_embeds" not in intermediate_parameters, (
"`ip_adapter_image_embeds` argument must be supported by the `__call__` method"
)
def _get_dummy_image_embeds(self, cross_attention_dim: int = 32): def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
return torch.randn((1, 1, cross_attention_dim), device=torch_device) return torch.randn((1, 1, cross_attention_dim), device=torch_device)
......
...@@ -139,7 +139,6 @@ class ModularPipelineTesterMixin: ...@@ -139,7 +139,6 @@ class ModularPipelineTesterMixin:
def test_pipeline_call_signature(self): def test_pipeline_call_signature(self):
pipe = self.get_pipeline() pipe = self.get_pipeline()
input_parameters = pipe.blocks.input_names input_parameters = pipe.blocks.input_names
intermediate_parameters = pipe.blocks.intermediate_input_names
optional_parameters = pipe.default_call_parameters optional_parameters = pipe.default_call_parameters
def _check_for_parameters(parameters, expected_parameters, param_type): def _check_for_parameters(parameters, expected_parameters, param_type):
...@@ -149,7 +148,6 @@ class ModularPipelineTesterMixin: ...@@ -149,7 +148,6 @@ class ModularPipelineTesterMixin:
) )
_check_for_parameters(self.params, input_parameters, "input") _check_for_parameters(self.params, input_parameters, "input")
_check_for_parameters(self.intermediate_params, intermediate_parameters, "intermediate")
_check_for_parameters(self.optional_params, optional_parameters, "optional") _check_for_parameters(self.optional_params, optional_parameters, "optional")
def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True): def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment