"tests/schedulers/test_scheduler_euler_ancestral.py" did not exist on "6a7a5467cab6df8bb24b20a7ad3f2223c1a2e8de"
Unverified Commit 04812de2 authored by Yang Yong (雍洋), committed by GitHub

Refactor Config System (#338)

parent 6a658f42
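
The diff below converts attribute-style config access (`config.seed`, `config.model_cls`, ...) to plain dict-style access (`config["seed"]`, `config.get(...)`), renames `save_video_path` to `save_result_path`, and adds an `s2v` task choice. The following is a minimal, self-contained sketch of that dict-access pattern; `build_config` is a hypothetical helper, not the repository's `set_config`, and only the key names and defaults are taken from the configs in this diff.

```python
# Minimal sketch of the dict-style config access this commit standardizes on.
# build_config is hypothetical; key names and default values mirror the diff.

def build_config(overrides: dict) -> dict:
    config = {
        "seed": 42,
        "task": "t2v",
        "enable_cfg": False,
        "sample_guide_scale": [3.5, 3.5],
        "denoising_step_list": [1000, 750, 500, 250],
        "save_result_path": "",  # renamed from save_video_path
    }
    config.update(overrides)     # CLI/JSON values override the defaults
    return config

if __name__ == "__main__":
    config = build_config({"task": "s2v", "save_result_path": "./output.mp4"})
    # Reads use config["key"] or config.get("key", default) instead of config.key
    print(config["task"], config["seed"], config.get("target_video_length", 81))
```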
......@@ -7,8 +7,10 @@
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": [3.5, 3.5],
"sample_guide_scale": [
3.5,
3.5
],
"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": true,
......@@ -17,5 +19,10 @@
"vae_cpu_offload": false,
"use_image_encoder": false,
"boundary_step_index": 2,
"denoising_step_list": [1000, 750, 500, 250]
"denoising_step_list": [
1000,
750,
500,
250
]
}
......@@ -7,8 +7,10 @@
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": [3.5, 3.5],
"sample_guide_scale": [
3.5,
3.5
],
"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": true,
......@@ -17,7 +19,12 @@
"vae_cpu_offload": false,
"use_image_encoder": false,
"boundary_step_index": 2,
"denoising_step_list": [1000, 750, 500, 250],
"denoising_step_list": [
1000,
750,
500,
250
],
"mm_config": {
"mm_type": "W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm"
},
......
......@@ -7,8 +7,10 @@
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": [4.0, 3.0],
"sample_guide_scale": [
4.0,
3.0
],
"sample_shift": 12.0,
"enable_cfg": true,
"cpu_offload": true,
......
......@@ -7,8 +7,10 @@
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": [4.0, 3.0],
"sample_guide_scale": [
4.0,
3.0
],
"sample_shift": 5.0,
"enable_cfg": false,
"cpu_offload": true,
......@@ -16,7 +18,12 @@
"t5_cpu_offload": false,
"vae_cpu_offload": false,
"boundary_step_index": 2,
"denoising_step_list": [1000, 750, 500, 250],
"denoising_step_list": [
1000,
750,
500,
250
],
"lora_configs": [
{
"name": "low_noise_model",
......
......@@ -5,11 +5,14 @@
"target_height": 704,
"target_width": 1280,
"num_channels_latents": 48,
"vae_stride": [4, 16, 16],
"vae_stride": [
4,
16,
16
],
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 5.0,
"sample_shift": 5.0,
"enable_cfg": true,
......
......@@ -5,11 +5,14 @@
"target_height": 704,
"target_width": 1280,
"num_channels_latents": 48,
"vae_stride": [4, 16, 16],
"vae_stride": [
4,
16,
16
],
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 5.0,
"sample_shift": 5.0,
"enable_cfg": true,
......
......@@ -5,11 +5,14 @@
"target_height": 704,
"target_width": 1280,
"num_channels_latents": 48,
"vae_stride": [4, 16, 16],
"vae_stride": [
4,
16,
16
],
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 5.0,
"sample_shift": 5.0,
"enable_cfg": true,
......
......@@ -5,11 +5,14 @@
"target_height": 704,
"target_width": 1280,
"num_channels_latents": 48,
"vae_stride": [4, 16, 16],
"vae_stride": [
4,
16,
16
],
"self_attn_1_type": "flash_attn3",
"cross_attn_1_type": "flash_attn3",
"cross_attn_2_type": "flash_attn3",
"seed": 42,
"sample_guide_scale": 5.0,
"sample_shift": 5.0,
"enable_cfg": true,
......
......@@ -67,7 +67,7 @@ python lightx2v/infer.py \
--model_path /path/to/model \
--config_json ./configs/video_frame_interpolation/wan_t2v.json \
--prompt "A beautiful sunset over the ocean" \
--save_video_path ./output.mp4
--save_result_path ./output.mp4
```
### Configuration Parameters
......@@ -136,7 +136,7 @@ python lightx2v/infer.py \
--model_path ./models/wan2.1 \
--config_json ./wan_t2v_vfi_32fps.json \
--prompt "A cat playing in the garden" \
--save_video_path ./output_32fps.mp4
--save_result_path ./output_32fps.mp4
```
### Higher Frame Rate Enhancement
......@@ -170,7 +170,7 @@ python lightx2v/infer.py \
--config_json ./wan_i2v_vfi_60fps.json \
--image_path ./input.jpg \
--prompt "Smooth camera movement" \
--save_video_path ./output_60fps.mp4
--save_result_path ./output_60fps.mp4
```
## Performance Considerations
......
......@@ -67,7 +67,7 @@ python lightx2v/infer.py \
--model_path /path/to/model \
--config_json ./configs/video_frame_interpolation/wan_t2v.json \
--prompt "美丽的海上日落" \
--save_video_path ./output.mp4
--save_result_path ./output.mp4
```
### Configuration Parameters
......@@ -136,7 +136,7 @@ python lightx2v/infer.py \
--model_path ./models/wan2.1 \
--config_json ./wan_t2v_vfi_32fps.json \
--prompt "一只小猫在花园里玩耍" \
--save_video_path ./output_32fps.mp4
--save_result_path ./output_32fps.mp4
```
### Higher Frame Rate Enhancement
......@@ -170,7 +170,7 @@ python lightx2v/infer.py \
--config_json ./wan_i2v_vfi_60fps.json \
--image_path ./input.jpg \
--prompt "平滑的相机运动" \
--save_video_path ./output_60fps.mp4
--save_result_path ./output_60fps.mp4
```
## Performance Considerations
......
......@@ -23,7 +23,7 @@ from lightx2v.utils.utils import seed_all
class BaseWorker:
@ProfilingContext4DebugL1("Init Worker Worker Cost:")
def __init__(self, args):
args.save_video_path = ""
args.save_result_path = ""
config = set_config(args)
logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
seed_all(config.seed)
......@@ -49,7 +49,7 @@ class BaseWorker:
self.runner.config["prompt"] = params["prompt"]
self.runner.config["negative_prompt"] = params.get("negative_prompt", "")
self.runner.config["image_path"] = params.get("image_path", "")
self.runner.config["save_video_path"] = params.get("save_video_path", "")
self.runner.config["save_result_path"] = params.get("save_result_path", "")
self.runner.config["seed"] = params.get("seed", self.fixed_config.get("seed", 42))
self.runner.config["audio_path"] = params.get("audio_path", "")
......@@ -92,7 +92,7 @@ class BaseWorker:
if stream_video_path is not None:
tmp_video_path = stream_video_path
params["save_video_path"] = tmp_video_path
params["save_result_path"] = tmp_video_path
return tmp_video_path, output_video_path
async def prepare_dit_inputs(self, inputs, data_manager):
......
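
A minimal, self-contained sketch (illustrative values, not repository code) of how `BaseWorker` copies request parameters into the runner config after the rename above; only the key names and `.get()` fallbacks are taken from the diff.

```python
# Illustrative only: mimics the param-to-config mapping shown in the diff
# after save_video_path -> save_result_path. The dicts are placeholders.
params = {
    "prompt": "A cat playing in the garden",
    "save_result_path": "./output.mp4",   # formerly "save_video_path"
    "seed": 42,
}

runner_config = {}
runner_config["prompt"] = params["prompt"]
runner_config["negative_prompt"] = params.get("negative_prompt", "")
runner_config["image_path"] = params.get("image_path", "")
runner_config["save_result_path"] = params.get("save_result_path", "")
runner_config["seed"] = params.get("seed", 42)
runner_config["audio_path"] = params.get("audio_path", "")

print(runner_config)
```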
......@@ -17,6 +17,7 @@ from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner # noqa: F401
from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner # noqa: F401
from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner # noqa: F401
from lightx2v.utils.envs import *
from lightx2v.utils.input_info import set_input_info
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.utils.set_config import print_config, set_config, set_parallel_config
......@@ -24,15 +25,15 @@ from lightx2v.utils.utils import seed_all
def init_runner(config):
seed_all(config.seed)
torch.set_grad_enabled(False)
runner = RUNNER_REGISTER[config.model_cls](config)
runner = RUNNER_REGISTER[config["model_cls"]](config)
runner.init_modules()
return runner
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=42, help="The seed for random generator")
parser.add_argument(
"--model_cls",
type=str,
......@@ -58,7 +59,7 @@ def main():
default="wan2.1",
)
parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate"], default="t2v")
parser.add_argument("--task", type=str, choices=["t2v", "i2v", "t2i", "i2i", "flf2v", "vace", "animate", "s2v"], default="t2v")
parser.add_argument("--model_path", type=str, required=True)
parser.add_argument("--sf_model_path", type=str, required=False)
parser.add_argument("--config_json", type=str, required=True)
......@@ -91,13 +92,16 @@ def main():
help="The file of the source mask. Default None.",
)
parser.add_argument("--save_video_path", type=str, default=None, help="The path to save video path/file")
parser.add_argument("--save_result_path", type=str, default=None, help="The path to save video path/file")
parser.add_argument("--return_result_tensor", action="store_true", help="Whether to return result tensor. (Useful for comfyui)")
args = parser.parse_args()
seed_all(args.seed)
# set config
config = set_config(args)
if config.parallel:
if config["parallel"]:
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank())
set_parallel_config(config)
......@@ -106,7 +110,8 @@ def main():
with ProfilingContext4DebugL1("Total Cost"):
runner = init_runner(config)
runner.run_pipeline()
input_info = set_input_info(args)
runner.run_pipeline(input_info)
# Clean up distributed process group
if dist.is_initialized():
......
......@@ -2,6 +2,7 @@ import os
import torch
import torch.distributed as dist
import torch.nn.functional as F
from loguru import logger
from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
......@@ -35,11 +36,11 @@ class WanAudioModel(WanModel):
raise ValueError(f"Unsupported quant_scheme: {self.config.get('adapter_quant_scheme', None)}")
else:
adapter_model_name = "audio_adapter_model.safetensors"
self.config.adapter_model_path = os.path.join(self.config.model_path, adapter_model_name)
self.config["adapter_model_path"] = os.path.join(self.config["model_path"], adapter_model_name)
adapter_offload = self.config.get("cpu_offload", False)
load_from_rank0 = self.config.get("load_from_rank0", False)
self.adapter_weights_dict = load_weights(self.config.adapter_model_path, cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0)
self.adapter_weights_dict = load_weights(self.config["adapter_model_path"], cpu_offload=adapter_offload, remove_key="audio", load_from_rank0=load_from_rank0)
if not adapter_offload:
if not dist.is_initialized() or not load_from_rank0:
for key in self.adapter_weights_dict:
......@@ -51,17 +52,17 @@ class WanAudioModel(WanModel):
self.post_infer_class = WanAudioPostInfer
self.transformer_infer_class = WanAudioTransformerInfer
def get_graph_name(self, shape, audio_num):
return f"graph_{shape[0]}x{shape[1]}_{audio_num}audio"
def get_graph_name(self, shape, audio_num, with_mask=True):
return f"graph_{shape[0]}x{shape[1]}_audio_num_{audio_num}_mask_{with_mask}"
def start_compile(self, shape, audio_num):
graph_name = self.get_graph_name(shape, audio_num)
def start_compile(self, shape, audio_num, with_mask=True):
graph_name = self.get_graph_name(shape, audio_num, with_mask)
logger.info(f"[Compile] Compile shape: {shape}, audio_num:{audio_num}, graph_name: {graph_name}")
target_video_length = self.config.get("target_video_length", 81)
latents_length = (target_video_length - 1) // 16 * 4 + 1
latents_h = shape[0] // self.config.vae_stride[1]
latents_w = shape[1] // self.config.vae_stride[2]
latents_h = shape[0] // self.config["vae_stride"][1]
latents_w = shape[1] // self.config["vae_stride"][2]
new_inputs = {}
new_inputs["text_encoder_output"] = {}
......@@ -73,7 +74,11 @@ class WanAudioModel(WanModel):
new_inputs["image_encoder_output"]["vae_encoder_out"] = torch.randn(16, 1, latents_h, latents_w, dtype=torch.bfloat16).cuda()
new_inputs["audio_encoder_output"] = torch.randn(audio_num, latents_length, 128, 1024, dtype=torch.bfloat16).cuda()
new_inputs["person_mask_latens"] = torch.zeros(audio_num, 1, (latents_h // 2), (latents_w // 2), dtype=torch.int8).cuda()
if with_mask:
new_inputs["person_mask_latens"] = torch.zeros(audio_num, 1, (latents_h // 2), (latents_w // 2), dtype=torch.int8).cuda()
else:
assert audio_num == 1, "audio_num must be 1 when with_mask is False"
new_inputs["person_mask_latens"] = None
new_inputs["previmg_encoder_output"] = {}
new_inputs["previmg_encoder_output"]["prev_latents"] = torch.randn(16, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()
......@@ -90,19 +95,21 @@ class WanAudioModel(WanModel):
self.enable_compile_mode("_infer_cond_uncond")
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config.model_cls:
if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config["model_cls"]:
self.to_cuda()
elif self.offload_granularity != "model":
self.pre_weight.to_cuda()
self.transformer_weights.non_block_weights_to_cuda()
max_audio_num_num = self.config.get("compile_max_audios", 1)
max_audio_num_num = self.config.get("compile_max_audios", 3)
for audio_num in range(1, max_audio_num_num + 1):
for shape in compile_shapes:
self.start_compile(shape, audio_num)
self.start_compile(shape, audio_num, with_mask=True)
if audio_num == 1:
self.start_compile(shape, audio_num, with_mask=False)
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config.model_cls:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]:
self.to_cpu()
elif self.offload_granularity != "model":
self.pre_weight.to_cpu()
......@@ -115,9 +122,10 @@ class WanAudioModel(WanModel):
for shape in compile_shapes:
assert shape in [[480, 832], [544, 960], [720, 1280], [832, 480], [960, 544], [1280, 720], [480, 480], [576, 576], [704, 704], [960, 960]]
def select_graph_for_compile(self):
logger.info(f"tgt_h, tgt_w : {self.config.get('tgt_h')}, {self.config.get('tgt_w')}, audio_num: {self.config.get('audio_num')}")
self.select_graph("_infer_cond_uncond", f"graph_{self.config.get('tgt_h')}x{self.config.get('tgt_w')}_{self.config.get('audio_num')}audio")
def select_graph_for_compile(self, input_info):
logger.info(f"target_h, target_w : {input_info.target_shape[0]}, {input_info.target_shape[1]}, audio_num: {input_info.audio_num}")
graph_name = self.get_graph_name(input_info.target_shape, input_info.audio_num, with_mask=input_info.with_mask)
self.select_graph("_infer_cond_uncond", graph_name)
logger.info(f"[Compile] Compile status: {self.get_compile_status()}")
@torch.no_grad()
......@@ -138,7 +146,7 @@ class WanAudioModel(WanModel):
if person_mask_latens is not None:
pre_infer_out.adapter_args["person_mask_latens"] = torch.chunk(person_mask_latens, world_size, dim=1)[cur_rank]
if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] == "i2v":
if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] in ["i2v", "s2v"]:
embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
padding_size = (world_size - (embed.shape[0] % world_size)) % world_size
if padding_size > 0:
......
......@@ -33,7 +33,7 @@ class WanAudioPreInfer(WanPreInfer):
infer_condition, latents, timestep_input = self.scheduler.infer_condition, self.scheduler.latents, self.scheduler.timestep_input
prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
hidden_states = latents
if self.config.model_cls != "wan2.2_audio":
if self.config["model_cls"] != "wan2.2_audio":
prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=0)
......@@ -101,7 +101,7 @@ class WanAudioPreInfer(WanPreInfer):
del out
torch.cuda.empty_cache()
if self.task == "i2v" and self.config.get("use_image_encoder", True):
if self.task in ["i2v", "s2v"] and self.config.get("use_image_encoder", True):
context_clip = weights.proj_0.apply(clip_fea)
if self.clean_cuda_cache:
del clip_fea
......
......@@ -140,7 +140,7 @@ class WanTransformerInferCausVid(WanOffloadTransformerInfer):
def infer_cross_attn(self, weights, x, context, block_idx):
norm3_out = weights.norm3.apply(x)
if self.task == "i2v":
if self.task in ["i2v", "s2v"]:
context_img = context[:257]
context = context[257:]
......@@ -169,7 +169,7 @@ class WanTransformerInferCausVid(WanOffloadTransformerInfer):
model_cls=self.config["model_cls"],
)
if self.task == "i2v":
if self.task in ["i2v", "s2v"]:
k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
......
......@@ -25,22 +25,22 @@ class WanTransformerInferCaching(WanOffloadTransformerInfer):
class WanTransformerInferTeaCaching(WanTransformerInferCaching):
def __init__(self, config):
super().__init__(config)
self.teacache_thresh = config.teacache_thresh
self.teacache_thresh = config["teacache_thresh"]
self.accumulated_rel_l1_distance_even = 0
self.previous_e0_even = None
self.previous_residual_even = None
self.accumulated_rel_l1_distance_odd = 0
self.previous_e0_odd = None
self.previous_residual_odd = None
self.use_ret_steps = config.use_ret_steps
self.use_ret_steps = config["use_ret_steps"]
if self.use_ret_steps:
self.coefficients = self.config.coefficients[0]
self.coefficients = self.config["coefficients"][0]
self.ret_steps = 5
self.cutoff_steps = self.config.infer_steps
self.cutoff_steps = self.config["infer_steps"]
else:
self.coefficients = self.config.coefficients[1]
self.coefficients = self.config["coefficients"][1]
self.ret_steps = 1
self.cutoff_steps = self.config.infer_steps - 1
self.cutoff_steps = self.config["infer_steps"] - 1
# calculate should_calc
@torch.no_grad()
......@@ -216,7 +216,7 @@ class WanTransformerInferTaylorCaching(WanTransformerInferCaching, BaseTaylorCac
else:
x = self.infer_using_cache(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
if self.config.enable_cfg:
if self.config["enable_cfg"]:
self.switch_status()
return x
......@@ -353,7 +353,7 @@ class WanTransformerInferAdaCaching(WanTransformerInferCaching):
else:
x = self.infer_using_cache(xt)
if self.config.enable_cfg:
if self.config["enable_cfg"]:
self.switch_status()
return x
......@@ -515,7 +515,7 @@ class AdaArgs:
# Moreg related attributes
self.previous_moreg = 1.0
self.moreg_strides = [1]
self.moreg_steps = [int(0.1 * config.infer_steps), int(0.9 * config.infer_steps)]
self.moreg_steps = [int(0.1 * config["infer_steps"]), int(0.9 * config["infer_steps"])]
self.moreg_hyp = [0.385, 8, 1, 2]
self.mograd_mul = 10
self.spatial_dim = 1536
......@@ -525,7 +525,7 @@ class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCac
def __init__(self, config):
super().__init__(config)
self.cnt = 0
self.teacache_thresh = config.teacache_thresh
self.teacache_thresh = config["teacache_thresh"]
self.accumulated_rel_l1_distance_even = 0
self.previous_e0_even = None
self.previous_residual_even = None
......@@ -534,15 +534,15 @@ class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCac
self.previous_residual_odd = None
self.cache_even = {}
self.cache_odd = {}
self.use_ret_steps = config.use_ret_steps
self.use_ret_steps = config["use_ret_steps"]
if self.use_ret_steps:
self.coefficients = self.config.coefficients[0]
self.coefficients = self.config["coefficients"][0]
self.ret_steps = 5 * 2
self.cutoff_steps = self.config.infer_steps * 2
self.cutoff_steps = self.config["infer_steps"] * 2
else:
self.coefficients = self.config.coefficients[1]
self.coefficients = self.config["coefficients"][1]
self.ret_steps = 1 * 2
self.cutoff_steps = self.config.infer_steps * 2 - 2
self.cutoff_steps = self.config["infer_steps"] * 2 - 2
# 1. get taylor step_diff when there is two caching_records in scheduler
def get_taylor_step_diff(self):
......@@ -625,7 +625,7 @@ class WanTransformerInferCustomCaching(WanTransformerInferCaching, BaseTaylorCac
else:
x = self.infer_using_cache(x)
if self.config.enable_cfg:
if self.config["enable_cfg"]:
self.switch_status()
self.cnt += 1
......@@ -690,12 +690,12 @@ class WanTransformerInferFirstBlock(WanTransformerInferCaching):
def __init__(self, config):
super().__init__(config)
self.residual_diff_threshold = config.residual_diff_threshold
self.residual_diff_threshold = config["residual_diff_threshold"]
self.prev_first_block_residual_even = None
self.prev_remaining_blocks_residual_even = None
self.prev_first_block_residual_odd = None
self.prev_remaining_blocks_residual_odd = None
self.downsample_factor = self.config.downsample_factor
self.downsample_factor = self.config["downsample_factor"]
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
ori_x = x.clone()
......@@ -727,7 +727,7 @@ class WanTransformerInferFirstBlock(WanTransformerInferCaching):
else:
x = self.infer_using_cache(x)
if self.config.enable_cfg:
if self.config["enable_cfg"]:
self.switch_status()
return x
......@@ -795,12 +795,12 @@ class WanTransformerInferDualBlock(WanTransformerInferCaching):
def __init__(self, config):
super().__init__(config)
self.residual_diff_threshold = config.residual_diff_threshold
self.residual_diff_threshold = config["residual_diff_threshold"]
self.prev_front_blocks_residual_even = None
self.prev_middle_blocks_residual_even = None
self.prev_front_blocks_residual_odd = None
self.prev_middle_blocks_residual_odd = None
self.downsample_factor = self.config.downsample_factor
self.downsample_factor = self.config["downsample_factor"]
def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
ori_x = x.clone()
......@@ -854,7 +854,7 @@ class WanTransformerInferDualBlock(WanTransformerInferCaching):
context,
)
if self.config.enable_cfg:
if self.config["enable_cfg"]:
self.switch_status()
return x
......@@ -921,8 +921,8 @@ class WanTransformerInferDualBlock(WanTransformerInferCaching):
class WanTransformerInferDynamicBlock(WanTransformerInferCaching):
def __init__(self, config):
super().__init__(config)
self.residual_diff_threshold = config.residual_diff_threshold
self.downsample_factor = self.config.downsample_factor
self.residual_diff_threshold = config["residual_diff_threshold"]
self.downsample_factor = self.config["downsample_factor"]
self.block_in_cache_even = {i: None for i in range(self.blocks_num)}
self.block_residual_cache_even = {i: None for i in range(self.blocks_num)}
......@@ -992,10 +992,10 @@ class WanTransformerInferDynamicBlock(WanTransformerInferCaching):
class WanTransformerInferMagCaching(WanTransformerInferCaching):
def __init__(self, config):
super().__init__(config)
self.magcache_thresh = config.magcache_thresh
self.K = config.magcache_K
self.retention_ratio = config.magcache_retention_ratio
self.mag_ratios = np.array(config.magcache_ratios)
self.magcache_thresh = config["magcache_thresh"]
self.K = config["magcache_K"]
self.retention_ratio = config["magcache_retention_ratio"]
self.mag_ratios = np.array(config["magcache_ratios"])
# {True: cond_param, False: uncond_param}
self.accumulated_err = {True: 0.0, False: 0.0}
self.accumulated_steps = {True: 0, False: 0}
......@@ -1011,10 +1011,10 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
step_index = self.scheduler.step_index
infer_condition = self.scheduler.infer_condition
if self.config.magcache_calibration:
if self.config["magcache_calibration"]:
skip_forward = False
else:
if step_index >= int(self.config.infer_steps * self.retention_ratio):
if step_index >= int(self.config["infer_steps"] * self.retention_ratio):
# conditional and unconditional in one list
cur_mag_ratio = self.mag_ratios[0][step_index] if infer_condition else self.mag_ratios[1][step_index]
# magnitude ratio between current step and the cached step
......@@ -1054,7 +1054,7 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
if self.config["cpu_offload"]:
previous_residual = previous_residual.cpu()
if self.config.magcache_calibration and step_index >= 1:
if self.config["magcache_calibration"] and step_index >= 1:
norm_ratio = ((previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).mean()).item()
norm_std = (previous_residual.norm(dim=-1) / self.residual_cache[infer_condition].norm(dim=-1)).std().item()
cos_dis = (1 - F.cosine_similarity(previous_residual, self.residual_cache[infer_condition], dim=-1, eps=1e-8)).mean().item()
......@@ -1083,7 +1083,7 @@ class WanTransformerInferMagCaching(WanTransformerInferCaching):
self.accumulated_steps = {True: 0, False: 0}
self.accumulated_ratio = {True: 1.0, False: 1.0}
self.residual_cache = {True: None, False: None}
if self.config.magcache_calibration:
if self.config["magcache_calibration"]:
print("norm ratio")
print(self.norm_ratio)
print("norm std")
......
......@@ -41,7 +41,7 @@ class WanPreInfer:
else:
context = inputs["text_encoder_output"]["context_null"]
if self.task in ["i2v", "flf2v", "animate"]:
if self.task in ["i2v", "flf2v", "animate", "s2v"]:
if self.config.get("use_image_encoder", True):
clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
......
......@@ -39,12 +39,12 @@ class WanSFTransformerInfer(WanTransformerInfer):
else:
self.device = torch.device("cuda")
self.dtype = torch.bfloat16
sf_config = self.config.sf_config
self.local_attn_size = sf_config.local_attn_size
sf_config = self.config["sf_config"]
self.local_attn_size = sf_config["local_attn_size"]
self.max_attention_size = 32760 if self.local_attn_size == -1 else self.local_attn_size * 1560
self.num_frame_per_block = sf_config.num_frame_per_block
self.num_transformer_blocks = sf_config.num_transformer_blocks
self.frame_seq_length = sf_config.frame_seq_length
self.num_frame_per_block = sf_config["num_frame_per_block"]
self.num_transformer_blocks = sf_config["num_transformer_blocks"]
self.frame_seq_length = sf_config["frame_seq_length"]
self._initialize_kv_cache(self.device, self.dtype)
self._initialize_crossattn_cache(self.device, self.dtype)
......
......@@ -11,13 +11,13 @@ from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, comp
class WanTransformerInfer(BaseTransformerInfer):
def __init__(self, config):
self.config = config
self.task = config.task
self.task = config["task"]
self.attention_type = config.get("attention_type", "flash_attn2")
self.blocks_num = config.num_layers
self.blocks_num = config["num_layers"]
self.phases_num = 3
self.has_post_adapter = False
self.num_heads = config.num_heads
self.head_dim = config.dim // config.num_heads
self.num_heads = config["num_heads"]
self.head_dim = config["dim"] // config["num_heads"]
self.window_size = config.get("window_size", (-1, -1))
self.parallel_attention = None
if config.get("rotary_chunk", False):
......@@ -203,7 +203,7 @@ class WanTransformerInfer(BaseTransformerInfer):
x.add_(y_out * gate_msa.squeeze())
norm3_out = phase.norm3.apply(x)
if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
context_img = context[:257]
context = context[257:]
else:
......@@ -211,7 +211,7 @@ class WanTransformerInfer(BaseTransformerInfer):
if self.sensitive_layer_dtype != self.infer_dtype:
context = context.to(self.infer_dtype)
if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True):
if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
context_img = context_img.to(self.infer_dtype)
n, d = self.num_heads, self.head_dim
......@@ -234,7 +234,7 @@ class WanTransformerInfer(BaseTransformerInfer):
model_cls=self.config["model_cls"],
)
if self.task in ["i2v", "flf2v", "animate"] and self.config.get("use_image_encoder", True) and context_img is not None:
if self.task in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True) and context_img is not None:
k_img = phase.cross_attn_norm_k_img.apply(phase.cross_attn_k_img.apply(context_img)).view(-1, n, d)
v_img = phase.cross_attn_v_img.apply(context_img).view(-1, n, d)
......
......@@ -62,15 +62,15 @@ class WanModel(CompiledMethodsMixin):
self.init_empty_model = False
self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
self.dit_quantized = self.config.mm_config.get("mm_type", "Default") != "Default"
self.dit_quantized = self.config["mm_config"].get("mm_type", "Default") != "Default"
if self.dit_quantized:
dit_quant_scheme = self.config.mm_config.get("mm_type").split("-")[1]
if self.config.model_cls == "wan2.1_distill":
dit_quant_scheme = self.config["mm_config"].get("mm_type").split("-")[1]
if self.config["model_cls"] == "wan2.1_distill":
dit_quant_scheme = "distill_" + dit_quant_scheme
if dit_quant_scheme == "gguf":
self.dit_quantized_ckpt = find_gguf_model_path(config, "dit_quantized_ckpt", subdir=dit_quant_scheme)
self.config.use_gguf = True
self.config["use_gguf"] = True
else:
self.dit_quantized_ckpt = find_hf_model_path(
config,
......@@ -87,7 +87,7 @@ class WanModel(CompiledMethodsMixin):
self.dit_quantized_ckpt = None
assert not self.config.get("lazy_load", False)
self.weight_auto_quant = self.config.mm_config.get("weight_auto_quant", False)
self.weight_auto_quant = self.config["mm_config"].get("weight_auto_quant", False)
if self.dit_quantized:
assert self.weight_auto_quant or self.dit_quantized_ckpt is not None
......@@ -158,7 +158,7 @@ class WanModel(CompiledMethodsMixin):
weight_dict = {}
for file_path in safetensors_files:
if self.config.get("adapter_model_path", None) is not None:
if self.config.adapter_model_path == file_path:
if self.config["adapter_model_path"] == file_path:
continue
file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
weight_dict.update(file_weights)
......@@ -367,7 +367,7 @@ class WanModel(CompiledMethodsMixin):
@torch.no_grad()
def infer(self, inputs):
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config.model_cls:
if self.offload_granularity == "model" and self.scheduler.step_index == 0 and "wan2.2_moe" not in self.config["model_cls"]:
self.to_cuda()
elif self.offload_granularity != "model":
self.pre_weight.to_cuda()
......@@ -400,7 +400,7 @@ class WanModel(CompiledMethodsMixin):
self.scheduler.noise_pred = self._infer_cond_uncond(inputs, infer_condition=True)
if self.cpu_offload:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config.model_cls:
if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1 and "wan2.2_moe" not in self.config["model_cls"]:
self.to_cpu()
elif self.offload_granularity != "model":
self.pre_weight.to_cpu()
......@@ -441,7 +441,7 @@ class WanModel(CompiledMethodsMixin):
pre_infer_out.x = torch.chunk(x, world_size, dim=0)[cur_rank]
if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] == "i2v":
if self.config["model_cls"] in ["wan2.2", "wan2.2_audio"] and self.config["task"] in ["i2v", "s2v"]:
embed, embed0 = pre_infer_out.embed, pre_infer_out.embed0
padding_size = (world_size - (embed.shape[0] % world_size)) % world_size
......