Unverified Commit b455dc94, authored by YiYi Xu and committed by GitHub

[modular] wan! (#12611)

* update, remove intermediate_inputs

* support image2video

* revert dynamic steps to simplify

* refactor vae encoder block

* support flf2video!

* add support for wan2.2 14B

* style

* Apply suggestions from code review

* input dynamic step -> additional input step

* up

* fix init

* update dtype
parent 04f9d2bf
@@ -407,6 +407,7 @@ else:
    "QwenImageModularPipeline",
    "StableDiffusionXLAutoBlocks",
    "StableDiffusionXLModularPipeline",
    "Wan22AutoBlocks",
    "WanAutoBlocks",
    "WanModularPipeline",
]
@@ -1090,6 +1091,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    QwenImageModularPipeline,
    StableDiffusionXLAutoBlocks,
    StableDiffusionXLModularPipeline,
    Wan22AutoBlocks,
    WanAutoBlocks,
    WanModularPipeline,
)
......
@@ -13,7 +13,7 @@
# limitations under the License.
import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -88,6 +88,19 @@ class AdaptiveProjectedGuidance(BaseGuidance):
        data_batches.append(data_batch)
        return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        pred = None
......
@@ -13,7 +13,7 @@
# limitations under the License.
import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -99,6 +99,19 @@ class AdaptiveProjectedMixGuidance(BaseGuidance):
        data_batches.append(data_batch)
        return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        pred = None
......
@@ -141,6 +141,16 @@ class AutoGuidance(BaseGuidance):
        data_batches.append(data_batch)
        return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        pred = None
......
@@ -13,7 +13,7 @@
# limitations under the License.
import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -99,6 +99,16 @@ class ClassifierFreeGuidance(BaseGuidance):
        data_batches.append(data_batch)
        return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        pred = None
......
@@ -13,7 +13,7 @@
# limitations under the License.
import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -85,6 +85,16 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
        data_batches.append(data_batch)
        return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        pred = None
......
@@ -226,6 +226,16 @@ class FrequencyDecoupledGuidance(BaseGuidance):
        data_batches.append(data_batch)
        return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        pred = None
......
@@ -166,6 +166,11 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
    def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
        raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
raise NotImplementedError("BaseGuidance::prepare_inputs_from_block_state must be implemented in subclasses.")
    def __call__(self, data: List["BlockState"]) -> Any:
        if not all(hasattr(d, "noise_pred") for d in data):
            raise ValueError("Expected all data to have `noise_pred` attribute.")
@@ -234,6 +239,51 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
        data_batch[cls._identifier_key] = identifier
        return BlockState(**data_batch)
@classmethod
def _prepare_batch_from_block_state(
cls,
input_fields: Dict[str, Union[str, Tuple[str, str]]],
data: "BlockState",
tuple_index: int,
identifier: str,
) -> "BlockState":
"""
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
Args:
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
A dictionary where the keys are the names of the fields that will be used to store the data once it is
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
to look up the required data provided for preparation. If a string is provided, it will be used as the
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
length 2 is provided, the first element must be the conditional data identifier and the second element
must be the unconditional data identifier or None.
data (`BlockState`):
The input data to be prepared.
tuple_index (`int`):
The index to use when accessing input fields that are tuples.
Returns:
`BlockState`: The prepared batch of data.
"""
from ..modular_pipelines.modular_pipeline import BlockState
data_batch = {}
for key, value in input_fields.items():
try:
if isinstance(value, str):
data_batch[key] = getattr(data, value)
elif isinstance(value, tuple):
data_batch[key] = getattr(data, value[tuple_index])
else:
# We've already checked that value is a string or a tuple of strings with length 2
pass
except AttributeError:
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)
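To make the `input_fields` contract concrete, the following self-contained sketch (not part of the commit; the field names and values are invented for illustration) applies the same lookup rule to a plain namespace:

from types import SimpleNamespace

data = SimpleNamespace(
    prompt_embeds="<cond embeds>",
    negative_prompt_embeds="<uncond embeds>",
    latents="<latents>",
)
input_fields = {
    # tuple: (conditional attr, unconditional attr) -> picked by tuple_index
    "prompt_embeds": ("prompt_embeds", "negative_prompt_embeds"),
    # plain string: the same attr is used for every batch
    "latents": "latents",
}

def lookup(input_fields, data, tuple_index):
    batch = {}
    for key, value in input_fields.items():
        if isinstance(value, str):
            batch[key] = getattr(data, value)
        elif isinstance(value, tuple):
            batch[key] = getattr(data, value[tuple_index])
    return batch

print(lookup(input_fields, data, 0))  # conditional batch: prompt_embeds + latents
print(lookup(input_fields, data, 1))  # unconditional batch: negative_prompt_embeds + latents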
    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
......
@@ -187,6 +187,26 @@ class PerturbedAttentionGuidance(BaseGuidance):
        data_batches.append(data_batch)
        return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
    # Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
    def forward(
        self,
......
@@ -183,6 +183,26 @@ class SkipLayerGuidance(BaseGuidance):
        data_batches.append(data_batch)
        return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
    def forward(
        self,
        pred_cond: torch.Tensor,
......
@@ -172,6 +172,26 @@ class SmoothedEnergyGuidance(BaseGuidance):
        data_batches.append(data_batch)
        return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
elif self.num_conditions == 2:
tuple_indices = [0, 1]
input_predictions = (
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
)
else:
tuple_indices = [0, 1, 0]
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
    def forward(
        self,
        pred_cond: torch.Tensor,
......
@@ -13,7 +13,7 @@
# limitations under the License.
import math
-from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
@@ -74,6 +74,16 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
        data_batches.append(data_batch)
        return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
        pred = None
......
@@ -45,7 +45,7 @@ else:
        "InsertableDict",
    ]
    _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
-    _import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
+    _import_structure["wan"] = ["WanAutoBlocks", "Wan22AutoBlocks", "WanModularPipeline"]
    _import_structure["flux"] = [
        "FluxAutoBlocks",
        "FluxModularPipeline",
@@ -90,7 +90,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        QwenImageModularPipeline,
    )
    from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
-    from .wan import WanAutoBlocks, WanModularPipeline
+    from .wan import Wan22AutoBlocks, WanAutoBlocks, WanModularPipeline
else:
    import sys
......
@@ -1441,6 +1441,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        components_manager: Optional[ComponentsManager] = None,
        collection: Optional[str] = None,
        modular_config_dict: Optional[Dict[str, Any]] = None,
        config_dict: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
@@ -1492,23 +1494,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
        - The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as
          `_blocks_class_name` in the config dict
        """
-        if blocks is None:
-            blocks_class_name = self.default_blocks_name
-            if blocks_class_name is not None:
-                diffusers_module = importlib.import_module("diffusers")
-                blocks_class = getattr(diffusers_module, blocks_class_name)
-                blocks = blocks_class()
-            else:
-                logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
-        self.blocks = blocks
-        self._components_manager = components_manager
-        self._collection = collection
-        self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
-        self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
-        # update component_specs and config_specs from modular_repo
-        if pretrained_model_name_or_path is not None:
+        if modular_config_dict is None and config_dict is None and pretrained_model_name_or_path is not None:
            cache_dir = kwargs.pop("cache_dir", None)
            force_download = kwargs.pop("force_download", False)
            proxies = kwargs.pop("proxies", None)
@@ -1524,52 +1511,59 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
                "local_files_only": local_files_only,
                "revision": revision,
            }
-            # try to load modular_model_index.json
-            try:
-                config_dict = self.load_config(pretrained_model_name_or_path, **load_config_kwargs)
-            except EnvironmentError as e:
-                logger.debug(f"modular_model_index.json not found: {e}")
-                config_dict = None
-            # update component_specs and config_specs based on modular_model_index.json
-            if config_dict is not None:
-                for name, value in config_dict.items():
-                    # all the components in modular_model_index.json are from_pretrained components
-                    if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3:
-                        library, class_name, component_spec_dict = value
-                        component_spec = self._dict_to_component_spec(name, component_spec_dict)
-                        component_spec.default_creation_method = "from_pretrained"
-                        self._component_specs[name] = component_spec
-                    elif name in self._config_specs:
-                        self._config_specs[name].default = value
-            # if modular_model_index.json is not found, try to load model_index.json
-            else:
-                logger.debug(" loading config from model_index.json")
-                try:
-                    from diffusers import DiffusionPipeline
-                    config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
-                except EnvironmentError as e:
-                    logger.debug(f" model_index.json not found in the repo: {e}")
-                    config_dict = None
-                # update component_specs and config_specs based on model_index.json
-                if config_dict is not None:
-                    for name, value in config_dict.items():
-                        if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2:
-                            library, class_name = value
-                            component_spec_dict = {
-                                "repo": pretrained_model_name_or_path,
-                                "subfolder": name,
-                                "type_hint": (library, class_name),
-                            }
-                            component_spec = self._dict_to_component_spec(name, component_spec_dict)
-                            component_spec.default_creation_method = "from_pretrained"
-                            self._component_specs[name] = component_spec
-                        elif name in self._config_specs:
-                            self._config_specs[name].default = value
+            modular_config_dict, config_dict = self._load_pipeline_config(
+                pretrained_model_name_or_path, **load_config_kwargs
+            )
+        if blocks is None:
+            if modular_config_dict is not None:
+                blocks_class_name = modular_config_dict.get("_blocks_class_name")
+            elif config_dict is not None:
+                blocks_class_name = self.get_default_blocks_name(config_dict)
+            else:
+                blocks_class_name = None
+            if blocks_class_name is not None:
+                diffusers_module = importlib.import_module("diffusers")
+                blocks_class = getattr(diffusers_module, blocks_class_name)
+                blocks = blocks_class()
+            else:
+                logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
+        self.blocks = blocks
+        self._components_manager = components_manager
+        self._collection = collection
+        self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components}
+        self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs}
+        # update component_specs and config_specs based on modular_model_index.json
+        if modular_config_dict is not None:
+            for name, value in modular_config_dict.items():
+                # all the components in modular_model_index.json are from_pretrained components
+                if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 3:
+                    library, class_name, component_spec_dict = value
+                    component_spec = self._dict_to_component_spec(name, component_spec_dict)
+                    component_spec.default_creation_method = "from_pretrained"
+                    self._component_specs[name] = component_spec
+                elif name in self._config_specs:
+                    self._config_specs[name].default = value
+        # if `modular_config_dict` is None (i.e. `modular_model_index.json` is not found), update based on `config_dict` (i.e. `model_index.json`)
+        elif config_dict is not None:
+            for name, value in config_dict.items():
+                if name in self._component_specs and isinstance(value, (tuple, list)) and len(value) == 2:
+                    library, class_name = value
+                    component_spec_dict = {
+                        "repo": pretrained_model_name_or_path,
+                        "subfolder": name,
+                        "type_hint": (library, class_name),
+                    }
+                    component_spec = self._dict_to_component_spec(name, component_spec_dict)
+                    component_spec.default_creation_method = "from_pretrained"
+                    self._component_specs[name] = component_spec
+                elif name in self._config_specs:
+                    self._config_specs[name].default = value
        if len(kwargs) > 0:
            logger.warning(f"Unexpected input '{kwargs.keys()}' provided. This input will be ignored.")
@@ -1601,6 +1595,35 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
                params[input_param.name] = input_param.default
        return params
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
return self.default_blocks_name
@classmethod
def _load_pipeline_config(
cls,
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
**load_config_kwargs,
):
try:
# try to load modular_model_index.json
modular_config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
return modular_config_dict, None
except EnvironmentError as e:
logger.debug(f" modular_model_index.json not found in the repo: {e}")
try:
logger.debug(" try to load model_index.json")
from diffusers import DiffusionPipeline
config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
return None, config_dict
except EnvironmentError as e:
logger.debug(f" model_index.json not found in the repo: {e}")
return None, None
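For reference, a minimal standalone sketch of the same two-stage lookup (the file names come from the diff above; the helper, the local-folder assumption, and plain json loading are illustrative, not the diffusers implementation):

import json
import os

def load_pipeline_config(repo_path):
    # modular_model_index.json wins; model_index.json is the fallback; otherwise (None, None)
    for filename in ("modular_model_index.json", "model_index.json"):
        candidate = os.path.join(repo_path, filename)
        if os.path.isfile(candidate):
            with open(candidate) as f:
                config = json.load(f)
            if filename == "modular_model_index.json":
                return config, None
            return None, config
    return None, None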
    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
@@ -1655,42 +1678,33 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
            "revision": revision,
        }
-        try:
-            # try to load modular_model_index.json
-            config_dict = cls.load_config(pretrained_model_name_or_path, **load_config_kwargs)
-        except EnvironmentError as e:
-            logger.debug(f" modular_model_index.json not found in the repo: {e}")
-            config_dict = None
+        modular_config_dict, config_dict = cls._load_pipeline_config(
+            pretrained_model_name_or_path, **load_config_kwargs
+        )
-        if config_dict is not None:
-            pipeline_class = _get_pipeline_class(cls, config=config_dict)
+        if modular_config_dict is not None:
+            pipeline_class = _get_pipeline_class(cls, config=modular_config_dict)
+        elif config_dict is not None:
+            from diffusers.pipelines.auto_pipeline import _get_model
+            logger.debug(" try to determine the modular pipeline class from model_index.json")
+            standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
+            model_name = _get_model(standard_pipeline_class.__name__)
+            pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
+            diffusers_module = importlib.import_module("diffusers")
+            pipeline_class = getattr(diffusers_module, pipeline_class_name)
        else:
-            try:
-                logger.debug(" try to load model_index.json")
-                from diffusers import DiffusionPipeline
-                from diffusers.pipelines.auto_pipeline import _get_model
-                config_dict = DiffusionPipeline.load_config(pretrained_model_name_or_path, **load_config_kwargs)
-            except EnvironmentError as e:
-                logger.debug(f" model_index.json not found in the repo: {e}")
-            if config_dict is not None:
-                logger.debug(" try to determine the modular pipeline class from model_index.json")
-                standard_pipeline_class = _get_pipeline_class(cls, config=config_dict)
-                model_name = _get_model(standard_pipeline_class.__name__)
-                pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(model_name, ModularPipeline.__name__)
-                diffusers_module = importlib.import_module("diffusers")
-                pipeline_class = getattr(diffusers_module, pipeline_class_name)
-            else:
-                # there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components
-                pipeline_class = cls
-                pretrained_model_name_or_path = None
+            # there is no config for modular pipeline, assuming that the pipeline block does not need any from_pretrained components
+            pipeline_class = cls
+            pretrained_model_name_or_path = None
        pipeline = pipeline_class(
            blocks=blocks,
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            components_manager=components_manager,
            collection=collection,
+            modular_config_dict=modular_config_dict,
+            config_dict=config_dict,
            **kwargs,
        )
        return pipeline
@@ -2134,7 +2148,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
                logger.warning(
                    f"\nFailed to create component {name}:\n"
                    f"- Component spec: {spec}\n"
-                    f"- load() called with kwargs: {component_load_kwargs}\n\n"
+                    f"- load() called with kwargs: {component_load_kwargs}\n"
+                    "If this component is not required for your workflow you can safely ignore this message.\n\n"
+                    "Traceback:\n"
                    f"{traceback.format_exc()}"
                )
......
@@ -21,16 +21,14 @@ except OptionalDependencyNotAvailable:
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
+   _import_structure["decoders"] = ["WanImageVaeDecoderStep"]
    _import_structure["encoders"] = ["WanTextEncoderStep"]
    _import_structure["modular_blocks"] = [
        "ALL_BLOCKS",
-       "AUTO_BLOCKS",
-       "TEXT2VIDEO_BLOCKS",
-       "WanAutoBeforeDenoiseStep",
+       "Wan22AutoBlocks",
        "WanAutoBlocks",
-       "WanAutoBlocks",
-       "WanAutoDecodeStep",
-       "WanAutoDenoiseStep",
+       "WanAutoImageEncoderStep",
+       "WanAutoVaeImageEncoderStep",
    ]
    _import_structure["modular_pipeline"] = ["WanModularPipeline"]
@@ -41,15 +39,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
+       from .decoders import WanImageVaeDecoderStep
        from .encoders import WanTextEncoderStep
        from .modular_blocks import (
            ALL_BLOCKS,
-           AUTO_BLOCKS,
-           TEXT2VIDEO_BLOCKS,
-           WanAutoBeforeDenoiseStep,
+           Wan22AutoBlocks,
            WanAutoBlocks,
-           WanAutoDecodeStep,
-           WanAutoDenoiseStep,
+           WanAutoImageEncoderStep,
+           WanAutoVaeImageEncoderStep,
        )
        from .modular_pipeline import WanModularPipeline
else:
......
@@ -13,10 +13,11 @@
# limitations under the License.
import inspect
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
import torch
+from ...models import WanTransformer3DModel
from ...schedulers import UniPCMultistepScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
@@ -34,6 +35,97 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
# configuration of guider is. # configuration of guider is.
def repeat_tensor_to_batch_size(
input_name: str,
input_tensor: torch.Tensor,
batch_size: int,
num_videos_per_prompt: int = 1,
) -> torch.Tensor:
"""Repeat tensor elements to match the final batch size.
This function expands a tensor's batch dimension to match the final batch size (batch_size * num_videos_per_prompt)
by repeating each element along dimension 0.
The input tensor must have batch size 1 or batch_size. The function will:
- If batch size is 1: repeat each element (batch_size * num_videos_per_prompt) times
- If batch size equals batch_size: repeat each element num_videos_per_prompt times
Args:
input_name (str): Name of the input tensor (used for error messages)
input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
batch_size (int): The base batch size (number of prompts)
num_videos_per_prompt (int, optional): Number of videos to generate per prompt. Defaults to 1.
Returns:
torch.Tensor: The repeated tensor with final batch size (batch_size * num_videos_per_prompt)
Raises:
ValueError: If input_tensor is not a torch.Tensor or has invalid batch size
Examples:
    tensor = torch.tensor([[1, 2, 3]])  # shape: [1, 3]
    repeated = repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_videos_per_prompt=2)
    repeated  # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: [4, 3]
    tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: [2, 3]
    repeated = repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_videos_per_prompt=2)
    repeated  # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) - shape: [4, 3]
"""
# make sure input is a tensor
if not isinstance(input_tensor, torch.Tensor):
raise ValueError(f"`{input_name}` must be a tensor")
# make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts
if input_tensor.shape[0] == 1:
repeat_by = batch_size * num_videos_per_prompt
elif input_tensor.shape[0] == batch_size:
repeat_by = num_videos_per_prompt
else:
raise ValueError(
    f"`{input_name}` must have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
)
# expand the tensor to match the batch_size * num_videos_per_prompt
input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)
return input_tensor
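As a quick self-contained check of the repeat rule above (re-implemented here with plain torch; the helper name is illustrative and not part of the commit):

import torch

def repeat_to_batch(t, batch_size, num_videos_per_prompt=1):
    # same rule as repeat_tensor_to_batch_size: batch 1 -> repeat batch_size * num_videos_per_prompt times,
    # batch == batch_size -> repeat each element num_videos_per_prompt times
    repeat_by = batch_size * num_videos_per_prompt if t.shape[0] == 1 else num_videos_per_prompt
    return t.repeat_interleave(repeat_by, dim=0)

t1 = torch.tensor([[1, 2, 3]])             # batch size 1
t2 = torch.tensor([[1, 2, 3], [4, 5, 6]])  # batch size == batch_size
assert repeat_to_batch(t1, batch_size=2, num_videos_per_prompt=2).shape == (4, 3)
assert repeat_to_batch(t2, batch_size=2, num_videos_per_prompt=2).tolist() == [[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]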
def calculate_dimension_from_latents(
latents: torch.Tensor, vae_scale_factor_temporal: int, vae_scale_factor_spatial: int
) -> Tuple[int, int, int]:
"""Calculate image dimensions from latent tensor dimensions.
This function converts latent temporal and spatial dimensions to image temporal and spatial dimensions by
multiplying the latent num_frames/height/width by the VAE scale factor.
Args:
latents (torch.Tensor): The latent tensor. Must have 5 dimensions:
    [batch, channels, frames, height, width]
vae_scale_factor_temporal (int): The scale factor used by the VAE to compress temporal dimension.
Typically 4 for most VAEs (video is 4x larger than latents in temporal dimension)
vae_scale_factor_spatial (int): The scale factor used by the VAE to compress spatial dimension.
Typically 8 for most VAEs (image is 8x larger than latents in each dimension)
Returns:
    Tuple[int, int, int]: The calculated dimensions as (num_frames, height, width)
Raises:
    ValueError: If the latents tensor doesn't have 5 dimensions
"""
if latents.ndim != 5:
raise ValueError(f"latents must have 5 dimensions, but got {latents.ndim}")
_, _, num_latent_frames, latent_height, latent_width = latents.shape
num_frames = (num_latent_frames - 1) * vae_scale_factor_temporal + 1
height = latent_height * vae_scale_factor_spatial
width = latent_width * vae_scale_factor_spatial
return num_frames, height, width
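A worked example of this arithmetic, assuming the usual Wan factors of 4 (temporal) and 8 (spatial) and an illustrative 21-frame, 60x104 latent:

import torch

latents = torch.zeros(1, 16, 21, 60, 104)  # hypothetical [batch, channels, frames, height, width] latent
vae_scale_factor_temporal, vae_scale_factor_spatial = 4, 8
_, _, f, h, w = latents.shape
num_frames = (f - 1) * vae_scale_factor_temporal + 1  # (21 - 1) * 4 + 1 = 81
height = h * vae_scale_factor_spatial                 # 60 * 8 = 480
width = w * vae_scale_factor_spatial                  # 104 * 8 = 832
assert (num_frames, height, width) == (81, 480, 832)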
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
@@ -94,7 +186,7 @@ def retrieve_timesteps(
    return timesteps, num_inference_steps
-class WanInputStep(ModularPipelineBlocks):
+class WanTextInputStep(ModularPipelineBlocks):
    model_name = "wan"
    @property
@@ -109,14 +201,15 @@ class WanInputStep(ModularPipelineBlocks):
        )
    @property
-    def inputs(self) -> List[InputParam]:
+    def expected_components(self) -> List[ComponentSpec]:
        return [
-            InputParam("num_videos_per_prompt", default=1),
+            ComponentSpec("transformer", WanTransformer3DModel),
        ]
    @property
-    def intermediate_inputs(self) -> List[str]:
+    def inputs(self) -> List[InputParam]:
        return [
+            InputParam("num_videos_per_prompt", default=1),
            InputParam(
                "prompt_embeds",
                required=True,
@@ -141,19 +234,7 @@ class WanInputStep(ModularPipelineBlocks):
            OutputParam(
                "dtype",
                type_hint=torch.dtype,
-                description="Data type of model tensor inputs (determined by `prompt_embeds`)",
-            ),
-            OutputParam(
-                "prompt_embeds",
-                type_hint=torch.Tensor,
-                kwargs_type="denoiser_input_fields",  # already in intermedites state but declare here again for denoiser_input_fields
-                description="text embeddings used to guide the image generation",
-            ),
-            OutputParam(
-                "negative_prompt_embeds",
-                type_hint=torch.Tensor,
-                kwargs_type="denoiser_input_fields",  # already in intermedites state but declare here again for denoiser_input_fields
-                description="negative text embeddings used to guide the image generation",
+                description="Data type of model tensor inputs (determined by `transformer.dtype`)",
            ),
        ]
@@ -194,6 +275,140 @@ class WanInputStep(ModularPipelineBlocks):
        return components, state
class WanAdditionalInputsStep(ModularPipelineBlocks):
model_name = "wan"
def __init__(
self,
image_latent_inputs: List[str] = ["first_frame_latents"],
additional_batch_inputs: List[str] = [],
    ):
        """Initialize a configurable step that standardizes the inputs for the denoising step.
        This step handles multiple common tasks to prepare inputs for the denoising step:
        1. For encoded image latents: uses them to update height/width if None, and expands their batch size
        2. For additional_batch_inputs: only expands batch dimensions to match the final batch size
        This is a dynamic block that allows you to configure which inputs to process.
        Args:
            image_latent_inputs (List[str], optional): Names of image latent tensors to process.
                In addition to adjusting the batch size of these inputs, they will be used to determine height/width.
                Can be a single string or list of strings. Defaults to ["first_frame_latents"].
            additional_batch_inputs (List[str], optional):
                Names of additional conditional input tensors to expand batch size. These tensors will only have their
                batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
                Defaults to [].
        Examples:
            # Configure to process first_frame_latents (default behavior)
            WanAdditionalInputsStep()
            # Configure to process multiple image latent inputs
            WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents", "last_frame_latents"])
            # Configure to process image latents and additional batch inputs
            WanAdditionalInputsStep(
                image_latent_inputs=["first_frame_latents"], additional_batch_inputs=["image_embeds"]
            )
        """
if not isinstance(image_latent_inputs, list):
image_latent_inputs = [image_latent_inputs]
if not isinstance(additional_batch_inputs, list):
additional_batch_inputs = [additional_batch_inputs]
self._image_latent_inputs = image_latent_inputs
self._additional_batch_inputs = additional_batch_inputs
super().__init__()
@property
def description(self) -> str:
# Functionality section
summary_section = (
"Input processing step that:\n"
" 1. For image latent inputs: Updates height/width if None, and expands batch size\n"
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
)
# Inputs info
inputs_info = ""
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
if self._image_latent_inputs:
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
# Placement guidance
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
return summary_section + inputs_info + placement_section
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(name="num_videos_per_prompt", default=1),
InputParam(name="batch_size", required=True),
InputParam(name="height"),
InputParam(name="width"),
InputParam(name="num_frames"),
]
# Add image latent inputs
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
# Add additional batch inputs
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
return inputs
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# Process image latent inputs (num_frames/height/width calculation and batch expansion)
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
# 1. Calculate num_frames, height/width from latents
num_frames, height, width = calculate_dimension_from_latents(
image_latent_tensor, components.vae_scale_factor_temporal, components.vae_scale_factor_spatial
)
block_state.num_frames = block_state.num_frames or num_frames
block_state.height = block_state.height or height
block_state.width = block_state.width or width
# 2. Expand batch size
image_latent_tensor = repeat_tensor_to_batch_size(
input_name=image_latent_input_name,
input_tensor=image_latent_tensor,
num_videos_per_prompt=block_state.num_videos_per_prompt,
batch_size=block_state.batch_size,
)
setattr(block_state, image_latent_input_name, image_latent_tensor)
# Process additional batch inputs (only batch expansion)
for input_name in self._additional_batch_inputs:
input_tensor = getattr(block_state, input_name)
if input_tensor is None:
continue
# Only expand batch size
input_tensor = repeat_tensor_to_batch_size(
input_name=input_name,
input_tensor=input_tensor,
num_videos_per_prompt=block_state.num_videos_per_prompt,
batch_size=block_state.batch_size,
)
setattr(block_state, input_name, input_tensor)
self.set_block_state(state, block_state)
return components, state
class WanSetTimestepsStep(ModularPipelineBlocks):
    model_name = "wan"
@@ -215,26 +430,15 @@ class WanSetTimestepsStep(ModularPipelineBlocks):
            InputParam("sigmas"),
        ]
-    @property
-    def intermediate_outputs(self) -> List[OutputParam]:
-        return [
-            OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
-            OutputParam(
-                "num_inference_steps",
-                type_hint=int,
-                description="The number of denoising steps to perform at inference time",
-            ),
-        ]
    @torch.no_grad()
    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
-        block_state.device = components._execution_device
+        device = components._execution_device
        block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
            components.scheduler,
            block_state.num_inference_steps,
-            block_state.device,
+            device,
            block_state.timesteps,
            block_state.sigmas,
        )
@@ -246,10 +450,6 @@ class WanSetTimestepsStep(ModularPipelineBlocks):
class WanPrepareLatentsStep(ModularPipelineBlocks):
    model_name = "wan"
-    @property
-    def expected_components(self) -> List[ComponentSpec]:
-        return []
    @property
    def description(self) -> str:
        return "Prepare latents step that prepares the latents for the text-to-video generation process"
@@ -262,11 +462,6 @@ class WanPrepareLatentsStep(ModularPipelineBlocks):
            InputParam("num_frames", type_hint=int),
            InputParam("latents", type_hint=Optional[torch.Tensor]),
            InputParam("num_videos_per_prompt", type_hint=int, default=1),
-        ]
-    @property
-    def intermediate_inputs(self) -> List[InputParam]:
-        return [
            InputParam("generator"),
            InputParam(
                "batch_size",
@@ -337,29 +532,106 @@ class WanPrepareLatentsStep(ModularPipelineBlocks):
    @torch.no_grad()
    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
+        self.check_inputs(components, block_state)
+        device = components._execution_device
+        dtype = torch.float32  # Wan latents should be torch.float32 for best quality
        block_state.height = block_state.height or components.default_height
        block_state.width = block_state.width or components.default_width
        block_state.num_frames = block_state.num_frames or components.default_num_frames
-        block_state.device = components._execution_device
-        block_state.dtype = torch.float32  # Wan latents should be torch.float32 for best quality
-        block_state.num_channels_latents = components.num_channels_latents
-        self.check_inputs(components, block_state)
        block_state.latents = self.prepare_latents(
            components,
-            block_state.batch_size * block_state.num_videos_per_prompt,
-            block_state.num_channels_latents,
-            block_state.height,
-            block_state.width,
-            block_state.num_frames,
-            block_state.dtype,
-            block_state.device,
-            block_state.generator,
-            block_state.latents,
+            batch_size=block_state.batch_size * block_state.num_videos_per_prompt,
+            num_channels_latents=components.num_channels_latents,
+            height=block_state.height,
+            width=block_state.width,
+            num_frames=block_state.num_frames,
+            dtype=dtype,
+            device=device,
+            generator=block_state.generator,
+            latents=block_state.latents,
        )
        self.set_block_state(state, block_state)
        return components, state
class WanPrepareFirstFrameLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "step that prepares the masked first frame latents and adds them to the latent condition"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("first_frame_latents", type_hint=Optional[torch.Tensor]),
InputParam("num_frames", type_hint=int),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
batch_size, _, _, latent_height, latent_width = block_state.first_frame_latents.shape
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, block_state.num_frames))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(block_state.first_frame_latents.device)
block_state.first_frame_latents = torch.concat([mask_lat_size, block_state.first_frame_latents], dim=1)
self.set_block_state(state, block_state)
return components, state
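To see the shapes this masking produces, a self-contained sketch with assumed sizes (81 frames, temporal factor 4, a 60x104 latent grid; 16 latent channels is the typical Wan VAE width, all assumptions for illustration):

import torch

batch, num_frames, scale_t, h, w = 1, 81, 4, 60, 104
mask = torch.ones(batch, 1, num_frames, h, w)
mask[:, :, 1:] = 0                                     # only the first frame is "given"
first = mask[:, :, 0:1].repeat_interleave(scale_t, dim=2)
mask = torch.cat([first, mask[:, :, 1:]], dim=2)       # 4 + 80 = 84 frame slots
mask = mask.view(batch, -1, scale_t, h, w)             # -> [1, 21, 4, 60, 104]
mask = mask.transpose(1, 2)                            # -> [1, 4, 21, 60, 104]
assert mask.shape == (1, 4, 21, 60, 104)
# concatenated with 16-channel first-frame latents along dim=1, the condition has 20 channels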
class WanPrepareFirstLastFrameLatentsStep(ModularPipelineBlocks):
model_name = "wan"
@property
def description(self) -> str:
return "step that prepares the masked latents with the first and last frames and adds them to the latent condition"
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("first_last_frame_latents", type_hint=Optional[torch.Tensor]),
InputParam("num_frames", type_hint=int),
]
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
batch_size, _, _, latent_height, latent_width = block_state.first_last_frame_latents.shape
mask_lat_size = torch.ones(batch_size, 1, block_state.num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, block_state.num_frames - 1))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(
first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(
batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
)
mask_lat_size = mask_lat_size.transpose(1, 2)
mask_lat_size = mask_lat_size.to(block_state.first_last_frame_latents.device)
block_state.first_last_frame_latents = torch.concat(
[mask_lat_size, block_state.first_last_frame_latents], dim=1
)
self.set_block_state(state, block_state)
return components, state
@@ -29,7 +29,7 @@ from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-class WanDecodeStep(ModularPipelineBlocks):
+class WanImageVaeDecoderStep(ModularPipelineBlocks):
    model_name = "wan"
    @property
@@ -50,12 +50,6 @@ class WanDecodeStep(ModularPipelineBlocks):
    @property
    def inputs(self) -> List[Tuple[str, Any]]:
-        return [
-            InputParam("output_type", default="pil"),
-        ]
-    @property
-    def intermediate_inputs(self) -> List[str]:
        return [
            InputParam(
                "latents",
@@ -80,25 +74,20 @@ class WanDecodeStep(ModularPipelineBlocks):
        block_state = self.get_block_state(state)
        vae_dtype = components.vae.dtype
-        if not block_state.output_type == "latent":
-            latents = block_state.latents
-            latents_mean = (
-                torch.tensor(components.vae.config.latents_mean)
-                .view(1, components.vae.config.z_dim, 1, 1, 1)
-                .to(latents.device, latents.dtype)
-            )
-            latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
-                1, components.vae.config.z_dim, 1, 1, 1
-            ).to(latents.device, latents.dtype)
-            latents = latents / latents_std + latents_mean
-            latents = latents.to(vae_dtype)
-            block_state.videos = components.vae.decode(latents, return_dict=False)[0]
-        else:
-            block_state.videos = block_state.latents
-        block_state.videos = components.video_processor.postprocess_video(
-            block_state.videos, output_type=block_state.output_type
-        )
+        latents = block_state.latents
+        latents_mean = (
+            torch.tensor(components.vae.config.latents_mean)
+            .view(1, components.vae.config.z_dim, 1, 1, 1)
+            .to(latents.device, latents.dtype)
+        )
+        latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
+            1, components.vae.config.z_dim, 1, 1, 1
+        ).to(latents.device, latents.dtype)
+        latents = latents / latents_std + latents_mean
+        latents = latents.to(vae_dtype)
+        block_state.videos = components.vae.decode(latents, return_dict=False)[0]
+        block_state.videos = components.video_processor.postprocess_video(block_state.videos, output_type="np")
        self.set_block_state(state, block_state)
......
@@ -16,96 +16,244 @@ from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
-    WanInputStep,
+    WanAdditionalInputsStep,
+    WanPrepareFirstFrameLatentsStep,
+    WanPrepareFirstLastFrameLatentsStep,
    WanPrepareLatentsStep,
    WanSetTimestepsStep,
+    WanTextInputStep,
+)
+from .decoders import WanImageVaeDecoderStep
+from .denoise import (
+    Wan22DenoiseStep,
+    Wan22Image2VideoDenoiseStep,
+    WanDenoiseStep,
+    WanFLF2VDenoiseStep,
+    WanImage2VideoDenoiseStep,
+)
+from .encoders import (
+    WanFirstLastFrameImageEncoderStep,
+    WanFirstLastFrameVaeImageEncoderStep,
+    WanImageCropResizeStep,
+    WanImageEncoderStep,
+    WanImageResizeStep,
+    WanTextEncoderStep,
+    WanVaeImageEncoderStep,
)
-from .decoders import WanDecodeStep
-from .denoise import WanDenoiseStep
-from .encoders import WanTextEncoderStep
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
-# before_denoise: text2vid
-class WanBeforeDenoiseStep(SequentialPipelineBlocks):
+# wan2.1
+# wan2.1: text2vid
+class WanCoreDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
-        WanInputStep,
+        WanTextInputStep,
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
+        WanDenoiseStep,
    ]
-    block_names = ["input", "set_timesteps", "prepare_latents"]
+    block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]
    @property
    def description(self):
        return (
-            "Before denoise step that prepare the inputs for the denoise step.\n"
+            "denoise block that takes encoded conditions and runs the denoising process.\n"
            + "This is a sequential pipeline blocks:\n"
-            + " - `WanInputStep` is used to adjust the batch size of the model inputs\n"
+            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+            + " - `WanDenoiseStep` is used to denoise the latents\n"
        )
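A rough sketch of how these blocks are typically assembled into a runnable pipeline, following the generic modular-pipelines workflow; the repo id, dtype, and the exact loading calls are assumptions and may differ:

import torch
from diffusers.modular_pipelines import WanAutoBlocks

blocks = WanAutoBlocks()  # auto blocks: text2video by default, image/last_image inputs switch tasks
pipe = blocks.init_pipeline("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")  # assumed repo id
pipe.load_components(torch_dtype=torch.bfloat16)
# calling pipe(prompt=..., num_frames=81) then runs the assembled sequence and decodes to `videos`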
-# before_denoise: all task (text2vid,)
-class WanAutoBeforeDenoiseStep(AutoPipelineBlocks):
+# wan2.1: image2video
+## image encoder
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanImageEncoderStep]
block_names = ["image_resize", "image_encoder"]
@property
def description(self):
return "Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings"
## vae encoder
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
model_name = "wan"
block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
block_names = ["image_resize", "vae_image_encoder"]
@property
def description(self):
return "Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent representation"
## denoise
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
-        WanBeforeDenoiseStep,
+        WanTextInputStep,
WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
WanSetTimestepsStep,
WanPrepareLatentsStep,
WanPrepareFirstFrameLatentsStep,
WanImage2VideoDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"set_timesteps",
"prepare_latents",
"prepare_first_frame_latents",
"denoise",
] ]
-    block_names = ["text2vid"]
-    block_trigger_inputs = [None]
    @property
    def description(self):
        return (
-            "Before denoise step that prepare the inputs for the denoise step.\n"
-            + "This is an auto pipeline block that works for text2vid.\n"
-            + " - `WanBeforeDenoiseStep` (text2vid) is used.\n"
+            "denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+            + "This is a sequential pipeline blocks:\n"
+            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
) )
# wan2.1: FLF2v
## image encoder
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
    model_name = "wan"
    block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
    block_names = ["image_resize", "last_image_resize", "image_encoder"]

    @property
    def description(self):
        return "FLF2V image encoder step that resizes and encodes the first and last frame images to generate the image embeddings"


## vae encoder
class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
    model_name = "wan"
    block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
    block_names = ["image_resize", "last_image_resize", "vae_image_encoder"]

    @property
    def description(self):
        return "FLF2V VAE image encoder step that resizes and encodes the first and last frame images to generate the latent conditions"
## denoise
class WanFLF2VCoreDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
        WanTextInputStep,
        WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"]),
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        WanPrepareFirstLastFrameLatentsStep,
        WanFLF2VDenoiseStep,
    ]
    block_names = [
        "input",
        "additional_inputs",
        "set_timesteps",
        "prepare_latents",
        "prepare_first_last_frame_latents",
        "denoise",
    ]

    @property
    def description(self):
        return (
            "Denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
            + "This is a sequential pipeline block:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `WanPrepareFirstLastFrameLatentsStep` is used to prepare the first and last frame latent conditions\n"
            + " - `WanFLF2VDenoiseStep` is used to denoise the latents\n"
        )
# wan2.1: auto blocks
## image encoder
class WanAutoImageEncoderStep(AutoPipelineBlocks):
    block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
    block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
    block_trigger_inputs = ["last_image", "image"]

    @property
    def description(self):
        return (
            "Image encoder step that encodes the image(s) to generate the image embeddings.\n"
            + "This is an auto pipeline block that works for image2video and flf2v tasks.\n"
            + " - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided.\n"
            + " - `WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided.\n"
            + " - if neither `last_image` nor `image` is provided, this step will be skipped."
        )
## vae encoder
class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
    block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
    block_names = ["flf2v_vae_image_encoder", "image2video_vae_image_encoder"]
    block_trigger_inputs = ["last_image", "image"]

    @property
    def description(self):
        return (
            "VAE image encoder step that encodes the image(s) to generate the image latents.\n"
            + "This is an auto pipeline block that works for image2video and flf2v tasks.\n"
            + " - `WanFLF2VVaeImageEncoderStep` (flf2v) is used when `last_image` is provided.\n"
            + " - `WanImage2VideoVaeImageEncoderStep` (image2video) is used when `image` is provided.\n"
            + " - if neither `last_image` nor `image` is provided, this step will be skipped."
        )
## denoise
class WanAutoDenoiseStep(AutoPipelineBlocks):
    block_classes = [
        WanFLF2VCoreDenoiseStep,
        WanImage2VideoCoreDenoiseStep,
        WanCoreDenoiseStep,
    ]
    block_names = ["flf2v", "image2video", "text2video"]
    block_trigger_inputs = ["first_last_frame_latents", "first_frame_latents", None]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents.\n"
            "This is an auto pipeline block that works for text2video, image2video, and flf2v tasks.\n"
            " - `WanFLF2VCoreDenoiseStep` (flf2v) is used when `first_last_frame_latents` is provided.\n"
            " - `WanImage2VideoCoreDenoiseStep` (image2video) is used when `first_frame_latents` is provided.\n"
            " - `WanCoreDenoiseStep` (text2video) is used when neither latent condition is provided.\n"
        )
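The `block_trigger_inputs` lists above drive a simple dispatch rule: the first trigger input the caller actually provides selects the matching sub-block, and a `None` entry marks the default branch (when there is no default, the step is skipped). The snippet below is a simplified schematic of that rule, not the actual `AutoPipelineBlocks` implementation.

from typing import Any, Dict, List, Optional

def select_block(block_names: List[str], trigger_inputs: List[Optional[str]], call_inputs: Dict[str, Any]) -> Optional[str]:
    # Walk the branches in order; pick the first one whose trigger input was provided.
    for name, trigger in zip(block_names, trigger_inputs):
        if trigger is None:  # a None trigger is the unconditional default branch
            return name
        if call_inputs.get(trigger) is not None:
            return name
    return None  # no branch triggered and no default: the step is skipped

# Mirrors WanAutoDenoiseStep: flf2v > image2video > text2video (default).
assert select_block(
    ["flf2v", "image2video", "text2video"],
    ["first_last_frame_latents", "first_frame_latents", None],
    {"first_frame_latents": "..."},
) == "image2video"

# Mirrors WanAutoImageEncoderStep: no default branch, so it is skipped without `image`/`last_image`.
assert select_block(["flf2v", "image2video"], ["last_image", "image"], {}) is None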
# auto pipeline blocks
class WanAutoBlocks(SequentialPipelineBlocks):
    block_classes = [
        WanTextEncoderStep,
        WanAutoImageEncoderStep,
        WanAutoVaeImageEncoderStep,
        WanAutoDenoiseStep,
        WanImageVaeDecoderStep,
    ]
    block_names = [
        "text_encoder",
        "image_encoder",
        "vae_image_encoder",
        "denoise",
        "decode",
    ]

    @property
...@@ -116,29 +264,211 @@ class WanAutoBlocks(SequentialPipelineBlocks):
        )
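For orientation, here is a minimal usage sketch of the auto blocks above. It is not taken from this PR: the repo id is a placeholder, and the `init_pipeline` / `load_components` / `output=` calls are assumed to behave as in the existing modular-pipeline examples; check the Modular Diffusers docs for the exact API.

import torch
from diffusers import WanAutoBlocks  # also exposed via diffusers.modular_pipelines

blocks = WanAutoBlocks()
pipe = blocks.init_pipeline("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")  # placeholder repo id
pipe.load_components(torch_dtype=torch.bfloat16)

# text2video: only `prompt` is required; passing `image` (image2video) or `image` + `last_image`
# (flf2v) would trigger the corresponding branches of the auto blocks instead.
videos = pipe(prompt="a cat surfing a wave at sunset", num_frames=81, output="videos")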
# wan22
# wan2.2: text2vid
## denoise
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
        WanTextInputStep,
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        Wan22DenoiseStep,
    ]
    block_names = ["input", "set_timesteps", "prepare_latents", "denoise"]

    @property
    def description(self):
        return (
            "Denoise block that takes encoded conditions and runs the denoising process.\n"
            + "This is a sequential pipeline block:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
        )
# wan2.2: image2video
## denoise
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
    block_classes = [
        WanTextInputStep,
        WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"]),
        WanSetTimestepsStep,
        WanPrepareLatentsStep,
        WanPrepareFirstFrameLatentsStep,
        Wan22Image2VideoDenoiseStep,
    ]
    block_names = [
        "input",
        "additional_inputs",
        "set_timesteps",
        "prepare_latents",
        "prepare_first_frame_latents",
        "denoise",
    ]

    @property
    def description(self):
        return (
            "Denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
            + "This is a sequential pipeline block:\n"
            + " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
            + " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
            + " - `WanSetTimestepsStep` is used to set the timesteps\n"
            + " - `WanPrepareLatentsStep` is used to prepare the latents\n"
            + " - `WanPrepareFirstFrameLatentsStep` is used to prepare the first frame latent conditions\n"
            + " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
        )
class Wan22AutoDenoiseStep(AutoPipelineBlocks):
    block_classes = [
        Wan22Image2VideoCoreDenoiseStep,
        Wan22CoreDenoiseStep,
    ]
    block_names = ["image2video", "text2video"]
    block_trigger_inputs = ["first_frame_latents", None]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents.\n"
            "This is an auto pipeline block that works for text2video and image2video tasks.\n"
            " - `Wan22Image2VideoCoreDenoiseStep` (image2video) is used when `first_frame_latents` is provided.\n"
            " - `Wan22CoreDenoiseStep` (text2video) is used when `first_frame_latents` is not provided.\n"
        )
class Wan22AutoBlocks(SequentialPipelineBlocks):
    block_classes = [
        WanTextEncoderStep,
        WanAutoVaeImageEncoderStep,
        Wan22AutoDenoiseStep,
        WanImageVaeDecoderStep,
    ]
    block_names = [
        "text_encoder",
        "vae_image_encoder",
        "denoise",
        "decode",
    ]

    @property
    def description(self):
        return (
            "Auto Modular pipeline for text-to-video and image-to-video using Wan2.2.\n"
            + "- for text-to-video generation, all you need to provide is `prompt`\n"
            + "- for image-to-video generation, provide `prompt` and `image`"
        )
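As with wan2.1, a minimal usage sketch for the wan2.2 auto blocks; the same caveats apply (the repo id and the `init_pipeline`/`load_components`/`output=` calls are assumptions, not taken from this PR).

import torch
from diffusers import Wan22AutoBlocks

blocks = Wan22AutoBlocks()
pipe = blocks.init_pipeline("Wan-AI/Wan2.2-TI2V-5B-Diffusers")  # placeholder repo id
pipe.load_components(torch_dtype=torch.bfloat16)

# Passing an `image` triggers the image2video branch; omitting it runs plain text2video.
videos = pipe(prompt="a paper boat drifting down a rainy street", output="videos")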
# presets for wan2.1 and wan2.2
# YiYi Notes: should we move these to doc?
# wan2.1
TEXT2VIDEO_BLOCKS = InsertableDict(
    [
        ("text_encoder", WanTextEncoderStep),
        ("input", WanTextInputStep),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("denoise", WanDenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)
IMAGE2VIDEO_BLOCKS = InsertableDict(
    [
        ("image_resize", WanImageResizeStep),
        ("image_encoder", WanImage2VideoImageEncoderStep),
        ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
        ("input", WanTextInputStep),
        ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("prepare_first_frame_latents", WanPrepareFirstFrameLatentsStep),
        ("denoise", WanImage2VideoDenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)
FLF2V_BLOCKS = InsertableDict(
    [
        ("image_resize", WanImageResizeStep),
        ("last_image_resize", WanImageCropResizeStep),
        ("image_encoder", WanFLF2VImageEncoderStep),
        ("vae_image_encoder", WanFLF2VVaeImageEncoderStep),
        ("input", WanTextInputStep),
        ("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("prepare_first_last_frame_latents", WanPrepareFirstLastFrameLatentsStep),
        ("denoise", WanFLF2VDenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)
AUTO_BLOCKS = InsertableDict(
    [
        ("text_encoder", WanTextEncoderStep),
        ("image_encoder", WanAutoImageEncoderStep),
        ("vae_image_encoder", WanAutoVaeImageEncoderStep),
        ("denoise", WanAutoDenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)
# wan2.2 presets
TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict(
    [
        ("text_encoder", WanTextEncoderStep),
        ("input", WanTextInputStep),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("denoise", Wan22DenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)


IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
    [
        ("image_resize", WanImageResizeStep),
        ("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
        ("input", WanTextInputStep),
        ("set_timesteps", WanSetTimestepsStep),
        ("prepare_latents", WanPrepareLatentsStep),
        ("denoise", Wan22DenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)


AUTO_BLOCKS_WAN22 = InsertableDict(
    [
        ("text_encoder", WanTextEncoderStep),
        ("vae_image_encoder", WanAutoVaeImageEncoderStep),
        ("denoise", Wan22AutoDenoiseStep),
        ("decode", WanImageVaeDecoderStep),
    ]
)
# presets for all blocks (wan2.1 and wan2.2)
ALL_BLOCKS = {
    "wan2.1": {
        "text2video": TEXT2VIDEO_BLOCKS,
        "image2video": IMAGE2VIDEO_BLOCKS,
        "flf2v": FLF2V_BLOCKS,
        "auto": AUTO_BLOCKS,
    },
    "wan2.2": {
        "text2video": TEXT2VIDEO_BLOCKS_WAN22,
        "image2video": IMAGE2VIDEO_BLOCKS_WAN22,
        "auto": AUTO_BLOCKS_WAN22,
    },
}
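The presets can also be assembled directly into a pipeline without the auto blocks. A sketch of that path, assuming `SequentialPipelineBlocks.from_blocks_dict` accepts these `InsertableDict` presets the same way it does for the other modular pipelines (the repo id is again a placeholder):

from diffusers.modular_pipelines import SequentialPipelineBlocks

# Build a fixed text2video pipeline from the wan2.1 preset instead of the trigger-based auto blocks.
t2v_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["wan2.1"]["text2video"])
pipe = t2v_blocks.init_pipeline("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")  # placeholder repo id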