Unverified Commit 04812de2 authored by Yang Yong (雍洋), committed by GitHub

Refactor Config System (#338)

parent 6a658f42
@@ -7,8 +7,10 @@
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
-    "seed": 42,
-    "sample_guide_scale": [3.5, 3.5],
+    "sample_guide_scale": [
+        3.5,
+        3.5
+    ],
     "sample_shift": 5.0,
     "enable_cfg": false,
     "cpu_offload": true,
@@ -17,5 +19,10 @@
     "vae_cpu_offload": false,
     "use_image_encoder": false,
     "boundary_step_index": 2,
-    "denoising_step_list": [1000, 750, 500, 250]
+    "denoising_step_list": [
+        1000,
+        750,
+        500,
+        250
+    ]
 }
@@ -7,8 +7,10 @@
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
-    "seed": 42,
-    "sample_guide_scale": [3.5, 3.5],
+    "sample_guide_scale": [
+        3.5,
+        3.5
+    ],
     "sample_shift": 5.0,
     "enable_cfg": false,
     "cpu_offload": true,
@@ -17,7 +19,12 @@
     "vae_cpu_offload": false,
     "use_image_encoder": false,
     "boundary_step_index": 2,
-    "denoising_step_list": [1000, 750, 500, 250],
+    "denoising_step_list": [
+        1000,
+        750,
+        500,
+        250
+    ],
     "mm_config": {
         "mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"
     },
...
@@ -7,8 +7,10 @@
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
-    "seed": 42,
-    "sample_guide_scale": [4.0, 3.0],
+    "sample_guide_scale": [
+        4.0,
+        3.0
+    ],
     "sample_shift": 12.0,
     "enable_cfg": true,
     "cpu_offload": true,
...
@@ -7,8 +7,10 @@
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
-    "seed": 42,
-    "sample_guide_scale": [4.0, 3.0],
+    "sample_guide_scale": [
+        4.0,
+        3.0
+    ],
     "sample_shift": 5.0,
     "enable_cfg": false,
     "cpu_offload": true,
@@ -16,7 +18,12 @@
     "t5_cpu_offload": false,
     "vae_cpu_offload": false,
     "boundary_step_index": 2,
-    "denoising_step_list": [1000, 750, 500, 250],
+    "denoising_step_list": [
+        1000,
+        750,
+        500,
+        250
+    ],
     "lora_configs": [
         {
             "name": "low_noise_model",
...
@@ -5,11 +5,14 @@
     "target_height": 704,
     "target_width": 1280,
     "num_channels_latents": 48,
-    "vae_stride": [4, 16, 16],
+    "vae_stride": [
+        4,
+        16,
+        16
+    ],
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
-    "seed": 42,
     "sample_guide_scale": 5.0,
     "sample_shift": 5.0,
     "enable_cfg": true,
...
@@ -5,11 +5,14 @@
     "target_height": 704,
     "target_width": 1280,
     "num_channels_latents": 48,
-    "vae_stride": [4, 16, 16],
+    "vae_stride": [
+        4,
+        16,
+        16
+    ],
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
-    "seed": 42,
     "sample_guide_scale": 5.0,
     "sample_shift": 5.0,
     "enable_cfg": true,
...
@@ -5,11 +5,14 @@
     "target_height": 704,
     "target_width": 1280,
     "num_channels_latents": 48,
-    "vae_stride": [4, 16, 16],
+    "vae_stride": [
+        4,
+        16,
+        16
+    ],
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
-    "seed": 42,
     "sample_guide_scale": 5.0,
     "sample_shift": 5.0,
     "enable_cfg": true,
...
@@ -5,11 +5,14 @@
     "target_height": 704,
     "target_width": 1280,
     "num_channels_latents": 48,
-    "vae_stride": [4, 16, 16],
+    "vae_stride": [
+        4,
+        16,
+        16
+    ],
     "self_attn_1_type": "flash_attn3",
     "cross_attn_1_type": "flash_attn3",
     "cross_attn_2_type": "flash_attn3",
-    "seed": 42,
     "sample_guide_scale": 5.0,
     "sample_shift": 5.0,
     "enable_cfg": true,
...
@@ -67,7 +67,7 @@ python lightx2v/infer.py \
     --model_path /path/to/model \
     --config_json ./configs/video_frame_interpolation/wan_t2v.json \
     --prompt "A beautiful sunset over the ocean" \
-    --save_video_path ./output.mp4
+    --save_result_path ./output.mp4
 ```

 ### Configuration Parameters
@@ -136,7 +136,7 @@ python lightx2v/infer.py \
     --model_path ./models/wan2.1 \
     --config_json ./wan_t2v_vfi_32fps.json \
     --prompt "A cat playing in the garden" \
-    --save_video_path ./output_32fps.mp4
+    --save_result_path ./output_32fps.mp4
 ```

 ### Higher Frame Rate Enhancement
@@ -170,7 +170,7 @@ python lightx2v/infer.py \
     --config_json ./wan_i2v_vfi_60fps.json \
     --image_path ./input.jpg \
     --prompt "Smooth camera movement" \
-    --save_video_path ./output_60fps.mp4
+    --save_result_path ./output_60fps.mp4
 ```

 ## Performance Considerations
...
@@ -67,7 +67,7 @@ python lightx2v/infer.py \
     --model_path /path/to/model \
     --config_json ./configs/video_frame_interpolation/wan_t2v.json \
     --prompt "美丽的海上日落" \
-    --save_video_path ./output.mp4
+    --save_result_path ./output.mp4
 ```

 ### 配置参数说明
@@ -136,7 +136,7 @@ python lightx2v/infer.py \
     --model_path ./models/wan2.1 \
     --config_json ./wan_t2v_vfi_32fps.json \
     --prompt "一只小猫在花园里玩耍" \
-    --save_video_path ./output_32fps.mp4
+    --save_result_path ./output_32fps.mp4
 ```

 ### 更高帧率增强
@@ -170,7 +170,7 @@ python lightx2v/infer.py \
     --config_json ./wan_i2v_vfi_60fps.json \
     --image_path ./input.jpg \
     --prompt "平滑的相机运动" \
-    --save_video_path ./output_60fps.mp4
+    --save_result_path ./output_60fps.mp4
 ```

 ## 性能考虑
...
@@ -23,7 +23,7 @@ from lightx2v.utils.utils import seed_all
 class BaseWorker:
     @ProfilingContext4DebugL1("Init Worker Worker Cost:")
     def __init__(self, args):
-        args.save_video_path = ""
+        args.save_result_path = ""
         config = set_config(args)
         logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         seed_all(config.seed)
@@ -49,7 +49,7 @@ class BaseWorker:
         self.runner.config["prompt"] = params["prompt"]
         self.runner.config["negative_prompt"] = params.get("negative_prompt", "")
         self.runner.config["image_path"] = params.get("image_path", "")
-        self.runner.config["save_video_path"] = params.get("save_video_path", "")
+        self.runner.config["save_result_path"] = params.get("save_result_path", "")
         self.runner.config["seed"] = params.get("seed", self.fixed_config.get("seed", 42))
         self.runner.config["audio_path"] = params.get("audio_path", "")
@@ -92,7 +92,7 @@ class BaseWorker:
         if stream_video_path is not None:
             tmp_video_path = stream_video_path
-        params["save_video_path"] = tmp_video_path
+        params["save_result_path"] = tmp_video_path
         return tmp_video_path, output_video_path

     async def prepare_dit_inputs(self, inputs, data_manager):
...
@@ -17,6 +17,7 @@ from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner  # noqa: F401
 from lightx2v.utils.envs import *
+from lightx2v.utils.input_info import set_input_info
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
@@ -24,15 +25,15 @@ from lightx2v.utils.utils import seed_all
 def init_runner(config):
-    seed_all(config.seed)
     torch.set_grad_enabled(False)
-    runner = RUNNER_REGISTER[config.model_cls](config)
+    runner = RUNNER_REGISTER[config["model_cls"]](config)
     runner.init_modules()
     return runner


 def main():
     parser = argparse.ArgumentParser()
+    parser.add_argument("--seed", type=int, default=42, help="The seed for random generator")
     parser.add_argument(
         "--model_cls",
         type=str,
@@ -58,7 +59,7 @@ def main():
         default="wan2.1",
     )
-    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate"], default="t2v")
+    parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate", "s2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--sf_model_path", type=str, required=False)
     parser.add_argument("--config_json", type=str, required=True)
@@ -91,13 +92,16 @@ def main():
         help="The file of the source mask. Default None.",
     )
-    parser.add_argument("--save_video_path", type=str, default=None, help="The path to save video path/file")
+    parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file")
+    parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)")
     args = parser.parse_args()

+    seed_all(args.seed)
+
     # set config
     config = set_config(args)
-    if config.parallel:
+    if config["parallel"]:
         dist.init_process_group(backend="nccl")
         torch.cuda.set_device(dist.get_rank())
         set_parallel_config(config)
@@ -106,7 +110,8 @@ def main():
     with ProfilingContext4DebugL1("Total Cost"):
         runner = init_runner(config)
-        runner.run_pipeline()
+        input_info = set_input_info(args)
+        runner.run_pipeline(input_info)

     # Clean up distributed process group
     if dist.is_initialized():
...
@@ -2,6 +2,7 @@ import os
 import torch
 import torch.distributed as dist
+import torch.nn.functional as F
 from loguru import logger

 from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
@@ -35,11 +36,11 @@ class WanAudioModel(WanModel):
                 raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
         else:
             adapter_model_name = "audio_adapter_model.safetensors"
-        self.config.adapter_model_path = os.path.join(self.config.model_path, adapter_model_name)
+        self.config["adapter_model_path"] = os.path.join(self.config["model_path"], adapter_model_name)
         adapter_offload = self.config.get("cpu_offload", False)
         load_from_rank0 = self.config.get("load_from_rank0", False)
-        self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0)
+        self.adapter_weights_dict = load_weights(self.config["adapter_model_path"], cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0)
         if not adapter_offload:
             if not dist.is_initialized() or not load_from_rank0:
                 for key in self.adapter_weights_dict:
@@ -51,17 +52,17 @@ class WanAudioModel(WanModel):
         self.post_infer_class = WanAudioPostInfer
         self.transformer_infer_class = WanAudioTransformerInfer

-    def get_graph_name(self, shape, audio_num):
-        return f"graph_{shape[0]}x{shape[1]}_{audio_num}audio"
+    def get_graph_name(self, shape, audio_num, with_mask=True):
+        return f"graph_{shape[0]}x{shape[1]}_audio_num_{audio_num}_mask_{with_mask}"

-    def start_compile(self, shape, audio_num):
-        graph_name = self.get_graph_name(shape, audio_num)
+    def start_compile(self, shape, audio_num, with_mask=True):
+        graph_name = self.get_graph_name(shape, audio_num, with_mask)
         logger.info(f"[Compile] Compile shape: {shape}, audio_num:{audio_num}, graph_name: {graph_name}")
         target_video_length = self.config.get("target_video_length", 81)
         latents_length = (target_video_length - 1) // 16 * 4 + 1
-        latents_h = shape[0] // self.config.vae_stride[1]
-        latents_w = shape[1] // self.config.vae_stride[2]
+        latents_h = shape[0] // self.config["vae_stride"][1]
+        latents_w = shape[1] // self.config["vae_stride"][2]

         new_inputs = {}
         new_inputs["text_encoder_output"] = {}
@@ -73,7 +74,11 @@ class WanAudioModel(WanModel):
         new_inputs["image_encoder_output"]["vae_encoder_out"] = torch.randn(16, 1, latents_h, latents_w, dtype=torch.bfloat16).cuda()
         new_inputs["audio_encoder_output"] = torch.randn(audio_num, latents_length, 128, 1024, dtype=torch.bfloat16).cuda()
-        new_inputs["person_mask_latens"] = torch.zeros(audio_num, 1, (latents_h // 2), (latents_w // 2), dtype=torch.int8).cuda()
+        if with_mask:
+            new_inputs["person_mask_latens"] = torch.zeros(audio_num, 1, (latents_h // 2), (latents_w // 2), dtype=torch.int8).cuda()
+        else:
+            assert audio_num == 1, "audio_num must be 1 when with_mask is False"
+            new_inputs["person_mask_latens"] = None

         new_inputs["previmg_encoder_output"] = {}
         new_inputs["previmg_encoder_output"]["prev_latents"] = torch.randn(16, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()
@@ -90,19 +95,21 @@ class WanAudioModel(WanModel):
         self.enable_compile_mode("_infer_cond_uncond")
         if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config.model_cls:
+            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config["model_cls"]:
                 self.to_cuda()
             elif self.offload_granularity != "model":
                 self.pre_weight.to_cuda()
                 self.transformer_weights.non_block_weights_to_cuda()

-        max_audio_num_num = self.config.get("compile_max_audios", 1)
+        max_audio_num_num = self.config.get("compile_max_audios", 3)
         for audio_num in range(1, max_audio_num_num + 1):
             for shape in compile_shapes:
-                self.start_compile(shape, audio_num)
+                self.start_compile(shape, audio_num, with_mask=True)
+                if audio_num == 1:
+                    self.start_compile(shape, audio_num, with_mask=False)

         if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config.model_cls:
+            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]:
                 self.to_cpu()
             elif self.offload_granularity != "model":
                 self.pre_weight.to_cpu()
@@ -115,9 +122,10 @@ class WanAudioModel(WanModel):
         for shape in compile_shapes:
             assert shape in [[480, 832], [544, 960], [720, 1280], [832, 480], [960, 544], [1280, 720], [480, 480], [576, 576], [704, 704], [960, 960]]

-    def select_graph_for_compile(self):
-        logger.info(f"tgt_h, tgt_w : {self.config.get('tgt_h')}, {self.config.get('tgt_w')}, audio_num: {self.config.get('audio_num')}")
-        self.select_graph("_infer_cond_uncond", f"graph_{self.config.get('tgt_h')}x{self.config.get('tgt_w')}_{self.config.get('audio_num')}audio")
+    def select_graph_for_compile(self, input_info):
+        logger.info(f"target_h, target_w : {input_info.target_shape[0]}, {input_info.target_shape[1]}, audio_num: {input_info.audio_num}")
+        graph_name = self.get_graph_name(input_info.target_shape, input_info.audio_num, with_mask=input_info.with_mask)
+        self.select_graph("_infer_cond_uncond", graph_name)
         logger.info(f"[Compile] Compile status: {self.get_compile_status()}")

     @torch.no_grad()
@@ -138,7 +146,7 @@ class WanAudioModel(WanModel):
             if person_mask_latens is not None:
                 pre_infer_out.adapter_args["person_mask_latens"] = torch.chunk(person_mask_latens, world_size, dim=1)[cur_rank]

-        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] == "i2v":
+        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] in ["i2v", "s2v"]:
             embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
             padding_size = (world_size - (embed.shape[0] % world_size)) % world_size
             if padding_size > 0:
...
@@ -33,7 +33,7 @@ class WanAudioPreInfer(WanPreInfer):
         infer_condition, latents, timestep_input = self.scheduler.infer_condition, self.scheduler.latents, self.scheduler.timestep_input
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
         hidden_states = latents
-        if self.config.model_cls != "wan2.2_audio":
+        if self.config["model_cls"] != "wan2.2_audio":
             prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
             hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=0)
@@ -101,7 +101,7 @@ class WanAudioPreInfer(WanPreInfer):
             del out
             torch.cuda.empty_cache()

-        if self.task == "i2v" and self.config.get("use_image_encoder", True):
+        if self.task in ["i2v", "s2v"] and self.config.get("use_image_encoder", True):
             context_clip = weights.proj_0.apply(clip_fea)
             if self.clean_cuda_cache:
                 del clip_fea
...
@@ -140,7 +140,7 @@ class WanTransformerInferCausVid(WanOffloadTransformerInfer):
     def infer_cross_attn(self, weights, x, context, block_idx):
         norm3_out = weights.norm3.apply(x)

-        if self.task == "i2v":
+        if self.task in ["i2v", "s2v"]:
             context_img = context[:257]
             context = context[257:]
@@ -169,7 +169,7 @@ class WanTransformerInferCausVid(WanOffloadTransformerInfer):
             model_cls=self.config["model_cls"],
         )

-        if self.task == "i2v":
+        if self.task in ["i2v", "s2v"]:
             k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
             v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
...
@@ -25,22 +25,22 @@ class WanTransformerInferCaching(WanOffloadTransformerInfer):
 class WanTransformerInferTeaCaching(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.teacache_thresh = config.teacache_thresh
+        self.teacache_thresh = config["teacache_thresh"]
         self.accumulated_rel_l1_distance_even = 0
         self.previous_e0_even = None
         self.previous_residual_even = None
         self.accumulated_rel_l1_distance_odd = 0
         self.previous_e0_odd = None
         self.previous_residual_odd = None
-        self.use_ret_steps = config.use_ret_steps
+        self.use_ret_steps = config["use_ret_steps"]
         if self.use_ret_steps:
-            self.coefficients = self.config.coefficients[0]
+            self.coefficients = self.config["coefficients"][0]
             self.ret_steps = 5
-            self.cutoff_steps = self.config.infer_steps
+            self.cutoff_steps = self.config["infer_steps"]
         else:
-            self.coefficients = self.config.coefficients[1]
+            self.coefficients = self.config["coefficients"][1]
             self.ret_steps = 1
-            self.cutoff_steps = self.config.infer_steps - 1
+            self.cutoff_steps = self.config["infer_steps"] - 1

     # calculate should_calc
     @torch.no_grad()
@@ -216,7 +216,7 @@ class WanTransformerInferTaylorCaching(WanTransformerInferCaching, BaseTaylorCac
         else:
             x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)

-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()

         return x
@@ -353,7 +353,7 @@ class WanTransformerInferAdaCaching(WanTransformerInferCaching):
         else:
             x = self.infer_using_cache(xt)

-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()

         return x
@@ -515,7 +515,7 @@ class AdaArgs:
         # Moreg related attributes
         self.previous_moreg = 1.0
         self.moreg_strides = [1]
-        self.moreg_steps = [int(0.1 * config.infer_steps), int(0.9 * config.infer_steps)]
+        self.moreg_steps = [int(0.1 * config["infer_steps"]), int(0.9 * config["infer_steps"])]
         self.moreg_hyp = [0.385, 8, 1, 2]
         self.mograd_mul = 10
         self.spatial_dim = 1536
@@ -525,7 +525,7 @@ class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCac
     def __init__(self, config):
         super().__init__(config)
         self.cnt = 0
-        self.teacache_thresh = config.teacache_thresh
+        self.teacache_thresh = config["teacache_thresh"]
         self.accumulated_rel_l1_distance_even = 0
         self.previous_e0_even = None
         self.previous_residual_even = None
@@ -534,15 +534,15 @@ class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCac
         self.previous_residual_odd = None
         self.cache_even = {}
         self.cache_odd = {}
-        self.use_ret_steps = config.use_ret_steps
+        self.use_ret_steps = config["use_ret_steps"]
         if self.use_ret_steps:
-            self.coefficients = self.config.coefficients[0]
+            self.coefficients = self.config["coefficients"][0]
             self.ret_steps = 5 * 2
-            self.cutoff_steps = self.config.infer_steps * 2
+            self.cutoff_steps = self.config["infer_steps"] * 2
         else:
-            self.coefficients = self.config.coefficients[1]
+            self.coefficients = self.config["coefficients"][1]
             self.ret_steps = 1 * 2
-            self.cutoff_steps = self.config.infer_steps * 2 - 2
+            self.cutoff_steps = self.config["infer_steps"] * 2 - 2

     # 1. get taylor step_diff when there is two caching_records in scheduler
     def get_taylor_step_diff(self):
@@ -625,7 +625,7 @@ class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCac
         else:
             x = self.infer_using_cache(x)

-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()

         self.cnt += 1
@@ -690,12 +690,12 @@ class WanTransformerInferFirstBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.residual_diff_threshold = config.residual_diff_threshold
+        self.residual_diff_threshold = config["residual_diff_threshold"]
         self.prev_first_block_residual_even = None
         self.prev_remaining_blocks_residual_even = None
         self.prev_first_block_residual_odd = None
         self.prev_remaining_blocks_residual_odd = None
-        self.downsample_factor = self.config.downsample_factor
+        self.downsample_factor = self.config["downsample_factor"]

     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         ori_x = x.clone()
@@ -727,7 +727,7 @@ class WanTransformerInferFirstBlock(WanTransformerInferCaching):
         else:
             x = self.infer_using_cache(x)

-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()

         return x
@@ -795,12 +795,12 @@ class WanTransformerInferDualBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.residual_diff_threshold = config.residual_diff_threshold
+        self.residual_diff_threshold = config["residual_diff_threshold"]
         self.prev_front_blocks_residual_even = None
         self.prev_middle_blocks_residual_even = None
         self.prev_front_blocks_residual_odd = None
         self.prev_middle_blocks_residual_odd = None
-        self.downsample_factor = self.config.downsample_factor
+        self.downsample_factor = self.config["downsample_factor"]

     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         ori_x = x.clone()
@@ -854,7 +854,7 @@ class WanTransformerInferDualBlock(WanTransformerInferCaching):
                 context,
             )

-        if self.config.enable_cfg:
+        if self.config["enable_cfg"]:
             self.switch_status()

         return x
@@ -921,8 +921,8 @@ class WanTransformerInferDualBlock(WanTransformerInferCaching):
 class WanTransformerInferDynamicBlock(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.residual_diff_threshold = config.residual_diff_threshold
-        self.downsample_factor = self.config.downsample_factor
+        self.residual_diff_threshold = config["residual_diff_threshold"]
+        self.downsample_factor = self.config["downsample_factor"]
         self.block_in_cache_even = {i: None for i in range(self.blocks_num)}
         self.block_residual_cache_even = {i: None for i in range(self.blocks_num)}
@@ -992,10 +992,10 @@ class WanTransformerInferDynamicBlock(WanTransformerInferCaching):
 class WanTransformerInferMagCaching(WanTransformerInferCaching):
     def __init__(self, config):
         super().__init__(config)
-        self.magcache_thresh = config.magcache_thresh
-        self.K = config.magcache_K
-        self.retention_ratio = config.magcache_retention_ratio
-        self.mag_ratios = np.array(config.magcache_ratios)
+        self.magcache_thresh = config["magcache_thresh"]
+        self.K = config["magcache_K"]
+        self.retention_ratio = config["magcache_retention_ratio"]
+        self.mag_ratios = np.array(config["magcache_ratios"])
         # {True: cond_param, False: uncond_param}
         self.accumulated_err = {True: 0.0, False: 0.0}
         self.accumulated_steps = {True: 0, False: 0}
@@ -1011,10 +1011,10 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
         step_index = self.scheduler.step_index
         infer_condition = self.scheduler.infer_condition

-        if self.config.magcache_calibration:
+        if self.config["magcache_calibration"]:
             skip_forward = False
         else:
-            if step_index >= int(self.config.infer_steps * self.retention_ratio):
+            if step_index >= int(self.config["infer_steps"] * self.retention_ratio):
                 # conditional and unconditional in one list
                 cur_mag_ratio = self.mag_ratios[0][step_index] if infer_condition else self.mag_ratios[1][step_index]
                 # magnitude ratio between current step and the cached step
@@ -1054,7 +1054,7 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
         if self.config["cpu_offload"]:
             previous_residual = previous_residual.cpu()

-        if self.config.magcache_calibration and step_index >= 1:
+        if self.config["magcache_calibration"] and step_index >= 1:
             norm_ratio = ((previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).mean()).item()
             norm_std = (previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).std().item()
             cos_dis = (1 - F.cosine_similarity(previous_residual, self.residual_cache[infer_condition], dim=-1, eps=1e-8)).mean().item()
@@ -1083,7 +1083,7 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
         self.accumulated_steps = {True: 0, False: 0}
         self.accumulated_ratio = {True: 1.0, False: 1.0}
         self.residual_cache = {True: None, False: None}
-        if self.config.magcache_calibration:
+        if self.config["magcache_calibration"]:
             print("norm ratio")
             print(self.norm_ratio)
             print("norm std")
...
@@ -41,7 +41,7 @@ class WanPreInfer:
         else:
             context = inputs["text_encoder_output"]["context_null"]

-        if self.task in ["i2v", "flf2v", "animate"]:
+        if self.task in ["i2v", "flf2v", "animate", "s2v"]:
             if self.config.get("use_image_encoder", True):
                 clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
...
@@ -39,12 +39,12 @@ class WanSFTransformerInfer(WanTransformerInfer):
         else:
             self.device = torch.device("cuda")
         self.dtype = torch.bfloat16
-        sf_config = self.config.sf_config
-        self.local_attn_size = sf_config.local_attn_size
+        sf_config = self.config["sf_config"]
+        self.local_attn_size = sf_config["local_attn_size"]
         self.max_attention_size = 32760 if self.local_attn_size == -1 else self.local_attn_size * 1560
-        self.num_frame_per_block = sf_config.num_frame_per_block
-        self.num_transformer_blocks = sf_config.num_transformer_blocks
-        self.frame_seq_length = sf_config.frame_seq_length
+        self.num_frame_per_block = sf_config["num_frame_per_block"]
+        self.num_transformer_blocks = sf_config["num_transformer_blocks"]
+        self.frame_seq_length = sf_config["frame_seq_length"]

         self._initialize_kv_cache(self.device, self.dtype)
         self._initialize_crossattn_cache(self.device, self.dtype)
...
@@ -11,13 +11,13 @@ from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, comp
 class WanTransformerInfer(BaseTransformerInfer):
     def __init__(self, config):
         self.config = config
-        self.task = config.task
+        self.task = config["task"]
         self.attention_type = config.get("attention_type", "flash_attn2")
-        self.blocks_num = config.num_layers
+        self.blocks_num = config["num_layers"]
         self.phases_num = 3
         self.has_post_adapter = False
-        self.num_heads = config.num_heads
-        self.head_dim = config.dim // config.num_heads
+        self.num_heads = config["num_heads"]
+        self.head_dim = config["dim"] // config["num_heads"]
         self.window_size = config.get("window_size", (-1, -1))
         self.parallel_attention = None
         if config.get("rotary_chunk", False):
@@ -203,7 +203,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         x.add_(y_out * gate_msa.squeeze())

         norm3_out = phase.norm3.apply(x)
-        if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
+        if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
             context_img = context[:257]
             context = context[257:]
         else:
@@ -211,7 +211,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         if self.sensitive_layer_dtype != self.infer_dtype:
             context = context.to(self.infer_dtype)
-            if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
+            if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
                 context_img = context_img.to(self.infer_dtype)

         n, d = self.num_heads, self.head_dim
@@ -234,7 +234,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             model_cls=self.config["model_cls"],
         )

-        if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True) and context_img is not None:
+        if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True) and context_img is not None:
             k_img = phase.cross_attn_norm_k_img.apply(phase.cross_attn_k_img.apply(context_img)).view(-1, n, d)
             v_img = phase.cross_attn_v_img.apply(context_img).view(-1, n, d)
...
@@ -62,15 +62,15 @@ class WanModel(CompiledMethodsMixin):
         self.init_empty_model = False
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)

-        self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
+        self.dit_quantized = self.config["mm_config"].get("mm_type", "Default") != "Default"
         if self.dit_quantized:
-            dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
-            if self.config.model_cls == "wan2.1_distill":
+            dit_quant_scheme = self.config["mm_config"].get("mm_type").split("-")[1]
+            if self.config["model_cls"] == "wan2.1_distill":
                 dit_quant_scheme = "distill_" + dit_quant_scheme
             if dit_quant_scheme == "gguf":
                 self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
-                self.config.use_gguf = True
+                self.config["use_gguf"] = True
             else:
                 self.dit_quantized_ckpt = find_hf_model_path(
                     config,
@@ -87,7 +87,7 @@ class WanModel(CompiledMethodsMixin):
             self.dit_quantized_ckpt = None
             assert not self.config.get("lazy_load", False)

-        self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
+        self.weight_auto_quant = self.config["mm_config"].get("weight_auto_quant", False)
         if self.dit_quantized:
             assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
@@ -158,7 +158,7 @@ class WanModel(CompiledMethodsMixin):
         weight_dict = {}
         for file_path in safetensors_files:
             if self.config.get("adapter_model_path", None) is not None:
-                if self.config.adapter_model_path == file_path:
+                if self.config["adapter_model_path"] == file_path:
                     continue
             file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
             weight_dict.update(file_weights)
@@ -367,7 +367,7 @@ class WanModel(CompiledMethodsMixin):
     @torch.no_grad()
     def infer(self, inputs):
         if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config.model_cls:
+            if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config["model_cls"]:
                 self.to_cuda()
             elif self.offload_granularity != "model":
                 self.pre_weight.to_cuda()
@@ -400,7 +400,7 @@ class WanModel(CompiledMethodsMixin):
             self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)

         if self.cpu_offload:
-            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config.model_cls:
+            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]:
                 self.to_cpu()
             elif self.offload_granularity != "model":
                 self.pre_weight.to_cpu()
@@ -441,7 +441,7 @@ class WanModel(CompiledMethodsMixin):
             pre_infer_out.x = torch.chunk(x, world_size, dim=0)[cur_rank]

-        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] == "i2v":
+        if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] in ["i2v", "s2v"]:
             embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
             padding_size = (world_size - (embed.shape[0] % world_size)) % world_size
             if padding_size > 0:
...