Unverified Commit b455dc94 authored by YiYi Xu's avatar YiYi Xu Committed by GitHub
Browse files

[modular] wan! (#12611)

* update, remove intermediaate_inputs

* support image2video

* revert dynamic steps to simplify

* refactor vae encoder block

* support flf2video!

* add support for wan2.2 14B

* style

* Apply suggestions from code review

* input dynamic step -> additiional input step

* up

* fix init

* update dtype
parent 04f9d2bf
......@@ -13,6 +13,8 @@
# limitations under the License.
from typing import Any, Dict, Optional
from ...loaders import WanLoraLoaderMixin
from ...pipelines.pipeline_utils import StableDiffusionMixin
from ...utils import logging
......@@ -35,6 +37,13 @@ class WanModularPipeline(
default_blocks_name = "WanAutoBlocks"
# override the default_blocks_name in base class, which is just return self.default_blocks_name
def get_default_blocks_name(self, config_dict: Optional[Dict[str, Any]]) -> Optional[str]:
if config_dict is not None and "boundary_ratio" in config_dict and config_dict["boundary_ratio"] is not None:
return "Wan22AutoBlocks"
else:
return "WanAutoBlocks"
@property
def default_height(self):
return self.default_sample_height * self.vae_scale_factor_spatial
......@@ -59,6 +68,13 @@ class WanModularPipeline(
def default_sample_num_frames(self):
return 21
@property
def patch_size_spatial(self):
patch_size_spatial = 2
if hasattr(self, "transformer") and self.transformer is not None:
patch_size_spatial = self.transformer.config.patch_size[1]
return patch_size_spatial
@property
def vae_scale_factor_spatial(self):
vae_scale_factor = 8
......@@ -86,3 +102,19 @@ class WanModularPipeline(
if hasattr(self, "vae") and self.vae is not None:
num_channels_latents = self.vae.config.z_dim
return num_channels_latents
@property
def requires_unconditional_embeds(self):
requires_unconditional_embeds = False
if hasattr(self, "guider") and self.guider is not None:
requires_unconditional_embeds = self.guider._enabled and self.guider.num_conditions > 1
return requires_unconditional_embeds
@property
def num_train_timesteps(self):
num_train_timesteps = 1000
if hasattr(self, "scheduler") and self.scheduler is not None:
num_train_timesteps = self.scheduler.config.num_train_timesteps
return num_train_timesteps
......@@ -117,6 +117,7 @@ from .stable_diffusion_xl import (
StableDiffusionXLInpaintPipeline,
StableDiffusionXLPipeline,
)
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
......@@ -214,6 +215,24 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
]
)
AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
[
("wan", WanPipeline),
]
)
AUTO_IMAGE2VIDEO_PIPELINES_MAPPING = OrderedDict(
[
("wan", WanImageToVideoPipeline),
]
)
AUTO_VIDEO2VIDEO_PIPELINES_MAPPING = OrderedDict(
[
("wan", WanVideoToVideoPipeline),
]
)
_AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING = OrderedDict(
[
("kandinsky", KandinskyPipeline),
......@@ -247,6 +266,9 @@ SUPPORTED_TASKS_MAPPINGS = [
AUTO_TEXT2IMAGE_PIPELINES_MAPPING,
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
AUTO_INPAINT_PIPELINES_MAPPING,
AUTO_TEXT2VIDEO_PIPELINES_MAPPING,
AUTO_IMAGE2VIDEO_PIPELINES_MAPPING,
AUTO_VIDEO2VIDEO_PIPELINES_MAPPING,
_AUTO_TEXT2IMAGE_DECODER_PIPELINES_MAPPING,
_AUTO_IMAGE2IMAGE_DECODER_PIPELINES_MAPPING,
_AUTO_INPAINT_DECODER_PIPELINES_MAPPING,
......
......@@ -182,6 +182,21 @@ class StableDiffusionXLModularPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class Wan22AutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class WanAutoBlocks(metaclass=DummyObject):
_backends = ["torch", "transformers"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment