Commit ddea3722 authored by wangshankun

Support: SkyReels-V2 AutoRegressive Diffusion-Forcing Inference

parent aec90a0d
@@ -20,6 +20,7 @@
 [Wan2.1-T2V-CausVid](https://huggingface.co/lightx2v/Wan2.1-T2V-14B-CausVid)
+[SkyReels-V2-DF](https://huggingface.co/Skywork/SkyReels-V2-DF-14B-540P)
 ## Build Env With Conda
 ```shell
......
{
    "infer_steps": 20,
    "target_video_length": 97,
    "text_len": 512,
    "target_height": 544,
    "target_width": 960,
    "num_frames": 257,
    "base_num_frames": 97,
    "overlap_history": 17,
    "addnoise_condition": 0,
    "causal_block_size": 1,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 5,
    "sample_shift": 3,
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
        "weight_auto_quant": true
    }
}
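For orientation: this first config (presumably the configs/wan_skyreels_v2_df.json referenced by the run script further down) drives the long-video diffusion-forcing path. num_frames is the total frame count, base_num_frames the per-window length, and overlap_history the pixel-frame overlap between consecutive windows. A minimal sketch of the window arithmetic, mirroring WanSkyreelsV2DFRunner.run() below (the helper name is ours):

```python
# Hypothetical helper mirroring the window arithmetic in WanSkyreelsV2DFRunner.run().
def df_window_plan(num_frames, base_num_frames, overlap_history):
    latent_length = (num_frames - 1) // 4 + 1   # VAE temporal stride of 4
    base = (base_num_frames - 1) // 4 + 1       # latent frames per window
    overlap = (overlap_history - 1) // 4 + 1    # latent frames shared between windows
    n_iter = 1 + (latent_length - base - 1) // (base - overlap) + 1
    return latent_length, base, overlap, n_iter

# With the values above: 257 frames -> 65 latent frames, windows of 25 latent
# frames with an overlap of 5, i.e. 3 sliding-window iterations.
print(df_window_plan(257, 97, 17))  # (65, 25, 5, 3)
```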
{
    "infer_steps": 30,
    "target_video_length": 97,
    "text_len": 512,
    "target_height": 544,
    "target_width": 960,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 5,
    "sample_shift": 3,
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
        "weight_auto_quant": true
    }
}
{
    "infer_steps": 30,
    "target_video_length": 97,
    "text_len": 512,
    "target_height": 544,
    "target_width": 960,
    "attention_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl",
        "weight_auto_quant": true
    }
}
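The two 544x960 configs (presumably wan_skyreels_v2_i2v.json and wan_skyreels_v2_t2v.json, which the run scripts below reference) differ only in guidance scale and sample shift. For either of them, the latent target shape and transformer sequence length follow from set_target_shape() and the scheduler's prepare() shown later; a worked sketch, assuming the usual Wan2.1 vae_stride (4, 8, 8) and patch_size (1, 2, 2):

```python
import math

# Assumed Wan2.1 defaults; the configs above do not override them.
vae_stride = (4, 8, 8)
patch_size = (1, 2, 2)

target_video_length, target_height, target_width = 97, 544, 960
target_shape = (
    16,
    (target_video_length - 1) // 4 + 1,   # 25 latent frames
    target_height // vae_stride[1],       # 68
    target_width // vae_stride[2],        # 120
)
seq_len = math.ceil(target_shape[2] * target_shape[3] / (patch_size[1] * patch_size[2]) * target_shape[1])
print(target_shape, seq_len)              # (16, 25, 68, 120) 51000
```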
 {
     "infer_steps": 50,
     "target_video_length": 81,
+    "text_len": 512,
     "target_height": 480,
     "target_width": 832,
     "attention_type": "flash_attn3",
......
@@ -29,7 +29,7 @@ def sage_attn2(q, k, v, cu_seqlens_q=None, cu_seqlens_kv=None, max_seqlen_q=None
         )
         x = torch.cat((x1, x2), dim=1)
         x = x.view(max_seqlen_q, -1)
-    elif model_cls in ["wan2.1", "wan2.1_causal"]:
+    elif model_cls in ["wan2.1", "wan2.1_causal", "wan2.1_df"]:
         x = sageattn(
             q.unsqueeze(0),
             k.unsqueeze(0),
......
@@ -12,9 +12,11 @@ from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.runners.wan.wan_causal_runner import WanCausalRunner
+from lightx2v.models.runners.wan.wan_skyreels_v2_df_runner import WanSkyreelsV2DFRunner
 from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.common.ops import *
+from loguru import logger


 def init_runner(config):
@@ -33,21 +35,21 @@ def init_runner(config):
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causal"], default="hunyuan")
+    parser.add_argument("--model_cls", type=str, required=True, choices=["wan2.1", "hunyuan", "wan2.1_causal", "wan2.1_skyreels_v2_df"], default="hunyuan")
     parser.add_argument("--task", type=str, choices=["t2v", "i2v"], default="t2v")
     parser.add_argument("--model_path", type=str, required=True)
     parser.add_argument("--config_json", type=str, required=True)
+    parser.add_argument("--enable_cfg", type=bool, default=False)
     parser.add_argument("--prompt", type=str, required=True)
     parser.add_argument("--negative_prompt", type=str, default="")
     parser.add_argument("--image_path", type=str, default="", help="The path to input image file or path for image-to-video (i2v) task")
     parser.add_argument("--save_video_path", type=str, default="./output_lightx2v.mp4", help="The path to save video path/file")

     args = parser.parse_args()
-    print(f"args: {args}")
+    logger.info(f"args: {args}")

     with ProfilingContext("Total Cost"):
         config = set_config(args)
-        print(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
+        logger.info(f"config:\n{json.dumps(config, ensure_ascii=False, indent=4)}")
         runner = init_runner(config)
         runner.run_pipeline()
@@ -12,7 +12,14 @@ class WanPostInfer:
         self.scheduler = scheduler

     def infer(self, weights, x, e, grid_sizes):
-        e = (weights.head_modulation + e.unsqueeze(1)).chunk(2, dim=1)
+        if e.dim() == 2:
+            modulation = weights.head_modulation  # 1, 2, dim
+            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
+        elif e.dim() == 3:  # for diffusion forcing
+            modulation = weights.head_modulation.unsqueeze(2)  # 1, 2, seq, dim
+            e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
+            e = [ei.squeeze(1) for ei in e]
         norm_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6).type_as(x)
         out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
         x = weights.head.apply(out)
......
@@ -27,7 +27,13 @@ class WanPreInfer:
     def infer(self, weights, inputs, positive):
         x = [self.scheduler.latents]
-        t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
+        if self.scheduler.flag_df:
+            t = self.scheduler.df_timesteps[self.scheduler.step_index].unsqueeze(0)
+            assert t.dim() == 2  # in DF inference the model timestep is 2-D (one timestep per latent frame)
+        else:
+            t = torch.stack([self.scheduler.timesteps[self.scheduler.step_index]])
         if positive:
             context = inputs["text_encoder_output"]["context"]
         else:
@@ -47,7 +53,7 @@ class WanPreInfer:
         assert seq_lens.max() <= seq_len
         x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])

-        embed = sinusoidal_embedding_1d(self.freq_dim, t)
+        embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         embed = weights.time_embedding_0.apply(embed)
         embed = torch.nn.functional.silu(embed)
         embed = weights.time_embedding_2.apply(embed)
@@ -55,6 +61,15 @@
         embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))

+        if self.scheduler.flag_df:
+            b, f = t.shape
+            assert b == len(x)  # batch_size == 1
+            embed = embed.view(b, f, 1, 1, self.dim)
+            embed0 = embed0.view(b, f, 1, 1, 6, self.dim)
+            embed = embed.repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1).flatten(1, 3)
+            embed0 = embed0.repeat(1, 1, grid_sizes[0][1], grid_sizes[0][2], 1, 1).flatten(1, 3)
+            embed0 = embed0.transpose(1, 2).contiguous()
+
         # text embeddings
         stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
         out = weights.text_embedding_0.apply(stacked.squeeze(0))
......
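The pre-infer change above is the model-side core of diffusion forcing: instead of one scalar timestep, every latent frame carries its own timestep, so the time embedding has to be broadcast over the spatial grid before being flattened to one vector per token (and the per-frame modulation picks up an extra sequence dimension, which the post-infer and transformer blocks handle via the dim() checks). A toy shape sketch of that broadcast, mirroring the flag_df branch with made-up sizes:

```python
import torch

# Toy sizes (frames x latent grid x hidden dim); real runs use grid_sizes and self.dim.
b, f, h, w, dim = 1, 5, 3, 4, 8
embed = torch.randn(b * f, dim)       # one time embedding per (batch, frame)
embed0 = torch.randn(b * f, 6, dim)   # after time_projection: 6 modulation vectors per frame

# Broadcast per-frame embeddings over the h*w spatial positions, then flatten to tokens.
embed = embed.view(b, f, 1, 1, dim).repeat(1, 1, h, w, 1).flatten(1, 3)          # (b, f*h*w, dim)
embed0 = embed0.view(b, f, 1, 1, 6, dim).repeat(1, 1, h, w, 1, 1).flatten(1, 3)  # (b, f*h*w, 6, dim)
embed0 = embed0.transpose(1, 2).contiguous()                                      # (b, 6, f*h*w, dim)
print(embed.shape, embed0.shape)  # torch.Size([1, 60, 8]) torch.Size([1, 6, 60, 8])
```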
@@ -78,7 +78,13 @@ class WanTransformerInfer:
         return x

     def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
-        embed0 = (weights.modulation + embed0).chunk(6, dim=1)
+        if embed0.dim() == 3:
+            modulation = weights.modulation.unsqueeze(2)  # 1, 6, 1, dim
+            embed0 = embed0.unsqueeze(0)
+            embed0 = (modulation + embed0).chunk(6, dim=1)
+            embed0 = [ei.squeeze(1) for ei in embed0]
+        elif embed0.dim() == 2:
+            embed0 = (weights.modulation + embed0).chunk(6, dim=1)
         norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
         norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
......
@@ -4,6 +4,7 @@ import torch.distributed as dist
 from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
 from lightx2v.utils.utils import save_videos_grid, cache_video
 from lightx2v.utils.envs import *
+from loguru import logger


 class DefaultRunner:
@@ -32,7 +33,7 @@ class DefaultRunner:
     def run(self):
         for step_index in range(self.model.scheduler.infer_steps):
-            print(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
+            logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")

             with ProfilingContext4Debug("step_pre"):
                 self.model.scheduler.step_pre(step_index=step_index)
@@ -67,7 +68,7 @@ class DefaultRunner:
     @ProfilingContext("Save video")
     def save_video(self, images):
         if not self.config.parallel_attn_type or (self.config.parallel_attn_type and dist.get_rank() == 0):
-            if self.config.model_cls in ["wan2.1", "wan2.1_causal"]:
+            if self.config.model_cls in ["wan2.1", "wan2.1_causal", "wan2.1_skyreels_v2_df"]:
                 cache_video(tensor=images, save_file=self.config.save_video_path, fps=16, nrow=1, normalize=True, value_range=(-1, 1))
             else:
                 save_videos_grid(images, self.config.save_video_path, fps=24)
......
import os
import gc
import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from lightx2v.utils.registry_factory import RUNNER_REGISTER
from lightx2v.models.runners.wan.wan_runner import WanRunner
from lightx2v.models.schedulers.wan.df.skyreels_v2_df_scheduler import WanSkyreelsV2DFScheduler
from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
import torch.distributed as dist
from loguru import logger


@RUNNER_REGISTER("wan2.1_skyreels_v2_df")
class WanSkyreelsV2DFRunner(WanRunner):  # diffusion forcing for SkyReels-V2 DF I2V/T2V
    def __init__(self, config):
        super().__init__(config)

    def init_scheduler(self):
        scheduler = WanSkyreelsV2DFScheduler(self.config)
        self.model.set_scheduler(scheduler)

    def run_image_encoder(self, config, image_encoder, vae_model):
        img = Image.open(config.image_path).convert("RGB")
        img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
        h, w = img.shape[1:]
        aspect_ratio = h / w
        max_area = config.target_height * config.target_width
        lat_h = round(np.sqrt(max_area * aspect_ratio) // config.vae_stride[1] // config.patch_size[1] * config.patch_size[1])
        lat_w = round(np.sqrt(max_area / aspect_ratio) // config.vae_stride[2] // config.patch_size[2] * config.patch_size[2])
        h = lat_h * config.vae_stride[1]
        w = lat_w * config.vae_stride[2]
        config.lat_h = lat_h
        config.lat_w = lat_w
        vae_encode_out = vae_model.encode([torch.nn.functional.interpolate(img[None].cpu(), size=(h, w), mode="bicubic").transpose(0, 1).cuda()], config)[0]
        vae_encode_out = vae_encode_out.to(torch.bfloat16)
        return vae_encode_out

    def set_target_shape(self):
        if os.path.isfile(self.config.image_path):
            self.config.target_shape = (16, (self.config.target_video_length - 1) // 4 + 1, self.config.lat_h, self.config.lat_w)
        else:
            self.config.target_shape = (
                16,
                (self.config.target_video_length - 1) // 4 + 1,
                int(self.config.target_height) // self.config.vae_stride[1],
                int(self.config.target_width) // self.config.vae_stride[2],
            )

    def run_input_encoder(self):
        image_encoder_output = None
        if os.path.isfile(self.config.image_path):
            with ProfilingContext("Run Img Encoder"):
                image_encoder_output = self.run_image_encoder(self.config, self.image_encoder, self.vae_model)
        with ProfilingContext("Run Text Encoder"):
            text_encoder_output = self.run_text_encoder(self.config["prompt"], self.text_encoders, self.config, image_encoder_output)
        self.set_target_shape()
        self.inputs = {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
        gc.collect()
        torch.cuda.empty_cache()

    def run(self):
        num_frames = self.config.num_frames
        overlap_history = self.config.overlap_history
        base_num_frames = self.config.base_num_frames
        addnoise_condition = self.config.addnoise_condition
        causal_block_size = self.config.causal_block_size
        latent_length = (num_frames - 1) // 4 + 1
        base_num_frames = (base_num_frames - 1) // 4 + 1 if base_num_frames is not None else latent_length
        overlap_history_frames = (overlap_history - 1) // 4 + 1
        n_iter = 1 + (latent_length - base_num_frames - 1) // (base_num_frames - overlap_history_frames) + 1
        prefix_video = self.inputs["image_encoder_output"]
        predix_video_latent_length = 0
        if prefix_video is not None:
            predix_video_latent_length = prefix_video.size(1)
        output_video = None
        logger.info(f"Diffusion-Forcing n_iter:{n_iter}")
        for i in range(n_iter):
            if output_video is not None:  # i != 0
                prefix_video = output_video[:, :, -overlap_history:].to(self.model.scheduler.device)
                prefix_video = self.vae_model.encode(prefix_video, self.config)[0]  # [(b, c, f, h, w)]
                if prefix_video.shape[1] % causal_block_size != 0:
                    truncate_len = prefix_video.shape[1] % causal_block_size
                    # the prefix video length is truncated for causal block size alignment
                    prefix_video = prefix_video[:, : prefix_video.shape[1] - truncate_len]
                predix_video_latent_length = prefix_video.shape[1]
                finished_frame_num = i * (base_num_frames - overlap_history_frames) + overlap_history_frames
                left_frame_num = latent_length - finished_frame_num
                base_num_frames_iter = min(left_frame_num + overlap_history_frames, base_num_frames)
            else:  # i == 0
                base_num_frames_iter = base_num_frames
            if prefix_video is not None:
                input_dtype = self.model.scheduler.latents.dtype
                self.model.scheduler.latents[:, :predix_video_latent_length] = prefix_video.to(input_dtype)
            self.model.scheduler.generate_timestep_matrix(base_num_frames_iter, base_num_frames_iter, addnoise_condition, predix_video_latent_length, causal_block_size)
            for step_index in range(self.model.scheduler.infer_steps):
                logger.info(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")
                with ProfilingContext4Debug("step_pre"):
                    self.model.scheduler.step_pre(step_index=step_index)
                with ProfilingContext4Debug("infer"):
                    self.model.infer(self.inputs)
                with ProfilingContext4Debug("step_post"):
                    self.model.scheduler.step_post()
            videos = self.run_vae(self.model.scheduler.latents, self.model.scheduler.generator)
            self.model.scheduler.prepare(self.inputs["image_encoder_output"])  # reset
            if output_video is None:
                output_video = videos.clamp(-1, 1).cpu()  # b, c, f, h, w
            else:
                output_video = torch.cat([output_video, videos[:, :, overlap_history:].clamp(-1, 1).cpu()], 2)
        return output_video

    def run_pipeline(self):
        self.init_scheduler()
        self.run_input_encoder()
        self.model.scheduler.prepare()
        output_video = self.run()
        self.end_run()
        self.save_video(output_video)
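run() stitches the windows in pixel space: after the first window, the last overlap_history decoded frames are re-encoded by the VAE as a prefix for the next window, and only the frames beyond that overlap are appended. A quick sanity check of the resulting length (hypothetical helper; assumes every window is full length, which holds for the default DF config):

```python
# Each full window decodes to 97 pixel frames (target_video_length); later windows
# drop the re-generated overlap before being appended to the running output.
def stitched_length(n_iter, frames_per_window, overlap_history):
    total = frames_per_window                      # first window kept in full
    for _ in range(n_iter - 1):
        total += frames_per_window - overlap_history
    return total

print(stitched_length(3, 97, 17))  # 257 == num_frames in the DF config
```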
@@ -6,6 +6,7 @@ class BaseScheduler:
         self.config = config
         self.step_index = 0
         self.latents = None
+        self.flag_df = False

     def step_pre(self, step_index):
         self.step_index = step_index
......
import os
import math
import numpy as np
import torch

from lightx2v.models.schedulers.wan.scheduler import WanScheduler


class WanSkyreelsV2DFScheduler(WanScheduler):
    def __init__(self, config):
        super().__init__(config)
        self.df_schedulers = []
        self.flag_df = True

    def prepare(self, image_encoder_output=None):
        self.generator = torch.Generator(device=self.device)
        self.generator.manual_seed(self.config.seed)
        self.prepare_latents(self.config.target_shape, dtype=torch.float32)
        if os.path.isfile(self.config.image_path):
            self.seq_len = ((self.config.target_video_length - 1) // self.config.vae_stride[0] + 1) * self.config.lat_h * self.config.lat_w // (self.config.patch_size[1] * self.config.patch_size[2])
        else:
            self.seq_len = math.ceil((self.config.target_shape[2] * self.config.target_shape[3]) / (self.config.patch_size[1] * self.config.patch_size[2]) * self.config.target_shape[1])
        alphas = np.linspace(1, 1 / self.num_train_timesteps, self.num_train_timesteps)[::-1].copy()
        sigmas = 1.0 - alphas
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32)
        sigmas = self.shift * sigmas / (1 + (self.shift - 1) * sigmas)
        self.sigmas = sigmas
        self.timesteps = sigmas * self.num_train_timesteps
        self.model_outputs = [None] * self.solver_order
        self.timestep_list = [None] * self.solver_order
        self.last_sample = None
        self.sigmas = self.sigmas.to("cpu")
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
        self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)

    def generate_timestep_matrix(
        self,
        num_frames,
        base_num_frames,
        addnoise_condition,
        num_pre_ready,
        casual_block_size=1,
        ar_step=0,
        shrink_interval_with_mask=False,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[tuple]]:
        self.addnoise_condition = addnoise_condition
        self.predix_video_latent_length = num_pre_ready
        step_template = self.timesteps
        step_matrix, step_index = [], []
        update_mask, valid_interval = [], []
        num_iterations = len(step_template) + 1
        num_frames_block = num_frames // casual_block_size
        base_num_frames_block = base_num_frames // casual_block_size
        if base_num_frames_block < num_frames_block:
            infer_step_num = len(step_template)
            gen_block = base_num_frames_block
            min_ar_step = infer_step_num / gen_block
            assert ar_step >= min_ar_step, f"ar_step should be at least {math.ceil(min_ar_step)} in your setting"
        # print(num_frames, step_template, base_num_frames, ar_step, num_pre_ready, casual_block_size, num_frames_block, base_num_frames_block)
        step_template = torch.cat(
            [
                torch.tensor([999], dtype=torch.int64, device=step_template.device),
                step_template.long(),
                torch.tensor([0], dtype=torch.int64, device=step_template.device),
            ]
        )  # pad so the per-row counters, which start from 1, index correctly
        pre_row = torch.zeros(num_frames_block, dtype=torch.long)
        if num_pre_ready > 0:
            pre_row[: num_pre_ready // casual_block_size] = num_iterations
        while not torch.all(pre_row >= (num_iterations - 1)):
            new_row = torch.zeros(num_frames_block, dtype=torch.long)
            for i in range(num_frames_block):
                if i == 0 or pre_row[i - 1] >= (num_iterations - 1):  # first frame, or the previous frame is fully denoised
                    new_row[i] = pre_row[i] + 1
                else:
                    new_row[i] = new_row[i - 1] - ar_step
            new_row = new_row.clamp(0, num_iterations)
            update_mask.append((new_row != pre_row) & (new_row != num_iterations))  # False: no need to update, True: need to update
            step_index.append(new_row)
            step_matrix.append(step_template[new_row])
            pre_row = new_row
        # for long videos we split into several sequences; base_num_frames is set to the model max length (for training)
        terminal_flag = base_num_frames_block
        if shrink_interval_with_mask:
            idx_sequence = torch.arange(num_frames_block, dtype=torch.int64)
            update_mask = update_mask[0]
            update_mask_idx = idx_sequence[update_mask]
            last_update_idx = update_mask_idx[-1].item()
            terminal_flag = last_update_idx + 1
        # for i in range(0, len(update_mask)):
        for curr_mask in update_mask:
            if terminal_flag < num_frames_block and curr_mask[terminal_flag]:
                terminal_flag += 1
            valid_interval.append((max(terminal_flag - base_num_frames_block, 0), terminal_flag))
        step_update_mask = torch.stack(update_mask, dim=0)
        step_index = torch.stack(step_index, dim=0)
        step_matrix = torch.stack(step_matrix, dim=0)
        if casual_block_size > 1:
            step_update_mask = step_update_mask.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
            step_index = step_index.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
            step_matrix = step_matrix.unsqueeze(-1).repeat(1, 1, casual_block_size).flatten(1).contiguous()
            valid_interval = [(s * casual_block_size, e * casual_block_size) for s, e in valid_interval]
        self.step_matrix = step_matrix
        self.step_update_mask = step_update_mask
        self.valid_interval = valid_interval
        self.df_timesteps = torch.zeros_like(self.step_matrix)
        self.df_schedulers = []
        for _ in range(base_num_frames):
            sample_scheduler = WanScheduler(self.config)
            sample_scheduler.prepare()
            self.df_schedulers.append(sample_scheduler)

    def step_pre(self, step_index):
        self.step_index = step_index
        self.latents = self.latents.to(dtype=torch.bfloat16)
        valid_interval_start, valid_interval_end = self.valid_interval[step_index]
        timestep = self.step_matrix[step_index][None, valid_interval_start:valid_interval_end].clone()
        if self.addnoise_condition > 0 and valid_interval_start < self.predix_video_latent_length:
            latent_model_input = self.latents[:, valid_interval_start:valid_interval_end, :, :].clone()
            noise_factor = 0.001 * self.addnoise_condition
            self.latents[:, valid_interval_start : self.predix_video_latent_length] = (
                latent_model_input[:, valid_interval_start : self.predix_video_latent_length] * (1.0 - noise_factor)
                + torch.randn_like(latent_model_input[:, valid_interval_start : self.predix_video_latent_length]) * noise_factor
            )
            timestep[:, valid_interval_start : self.predix_video_latent_length] = self.addnoise_condition
        self.df_timesteps[step_index] = timestep

    def step_post(self):
        update_mask_i = self.step_update_mask[self.step_index]
        valid_interval_start, valid_interval_end = self.valid_interval[self.step_index]
        timestep = self.df_timesteps[self.step_index]
        for idx in range(valid_interval_start, valid_interval_end):  # step each frame independently
            if update_mask_i[idx].item():
                self.df_schedulers[idx].step_pre(step_index=self.step_index)
                self.df_schedulers[idx].noise_pred = self.noise_pred[:, idx - valid_interval_start]
                self.df_schedulers[idx].timesteps[self.step_index] = timestep[idx]
                self.df_schedulers[idx].latents = self.latents[:, idx]
                self.df_schedulers[idx].step_post()
                self.latents[:, idx] = self.df_schedulers[idx].latents
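generate_timestep_matrix is where the autoregressive schedule lives: each row of step_matrix assigns one timestep per latent frame, and ar_step controls how far later frames lag behind earlier ones (ar_step=0, the value used here, denoises all frames of a window synchronously). A stripped-down, standalone re-creation of that loop (our own simplified helper: no causal blocks, no update mask) to make the schedule visible:

```python
import torch

def timestep_matrix(num_frames, step_template, num_pre_ready=0, ar_step=0):
    """Simplified re-creation of the row-building loop in generate_timestep_matrix."""
    num_iterations = len(step_template) + 1
    # pad with sentinels so the per-row counters, which start from 1, index correctly
    padded = torch.cat([
        torch.tensor([999], dtype=torch.int64),
        step_template.long(),
        torch.tensor([0], dtype=torch.int64),
    ])
    pre_row = torch.zeros(num_frames, dtype=torch.long)
    pre_row[:num_pre_ready] = num_iterations  # prefix frames are treated as already denoised
    rows = []
    while not torch.all(pre_row >= (num_iterations - 1)):
        new_row = torch.zeros(num_frames, dtype=torch.long)
        for i in range(num_frames):
            if i == 0 or pre_row[i - 1] >= (num_iterations - 1):
                new_row[i] = pre_row[i] + 1            # advance once the left neighbour is done
            else:
                new_row[i] = new_row[i - 1] - ar_step  # otherwise lag ar_step steps behind it
        new_row = new_row.clamp(0, num_iterations)
        rows.append(padded[new_row])
        pre_row = new_row
    return torch.stack(rows)

steps = torch.tensor([750, 500, 250])
print(timestep_matrix(4, steps, ar_step=0))  # every frame shares one timestep per row
print(timestep_matrix(4, steps, ar_step=1))  # later frames trail their left neighbour
```

The real method additionally records an update mask and a valid interval per row, repeats everything across causal blocks when casual_block_size > 1, and keeps one WanScheduler per frame so step_post() can advance each frame's own solver state.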
@@ -18,7 +18,7 @@ class WanScheduler(BaseScheduler):
         self.solver_order = 2
         self.noise_pred = None

-    def prepare(self, image_encoder_output):
+    def prepare(self, image_encoder_output=None):
         self.generator = torch.Generator(device=self.device)
         self.generator.manual_seed(self.config.seed)
......
#!/bin/bash
# set paths first
lightx2v_path="/mnt/Text2Video/wangshankun/tmp_code/lightx2v/"
model_path="/mnt/Text2Video/wangshankun/HF_Cache/hub/models--Skywork--SkyReels-V2-DF-14B-540P/snapshots/7ff972ba7b6a33d2f6e6c976dd3cf2d36984eee4/"
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
#I2V
python -m lightx2v.infer \
--model_cls wan2.1_skyreels_v2_df \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_skyreels_v2_df.json \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_skyreels_v2_df.mp4
#T2V
#python -m lightx2v.infer \
#--model_cls wan2.1_skyreels_v2_df \
#--task t2v \
#--model_path $model_path \
#--config_json ${lightx2v_path}/configs/wan_skyreels_v2_df.json \
#--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
#--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
#--save_video_path ${lightx2v_path}/save_results/output_lightx2v_skyreels_v2_df.mp4
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.infer \
--model_cls wan2.1 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_skyreels_v2_i2v.json \
--prompt "Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard. The fluffy-furred feline gazes directly at the camera with a relaxed expression. Blurred beach scenery forms the background featuring crystal-clear waters, distant green hills, and a blue sky dotted with white clouds. The cat assumes a naturally relaxed posture, as if savoring the sea breeze and warm sunlight. A close-up shot highlights the feline's intricate details and the refreshing atmosphere of the seaside." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--image_path ${lightx2v_path}/assets/inputs/imgs/img_0.jpg \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_skyreels_v2_i2v.mp4
#!/bin/bash
# set paths first
lightx2v_path=
model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using defalt value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export ENABLE_PROFILING_DEBUG=true
export ENABLE_GRAPH_MODE=false
python -m lightx2v.infer \
--model_cls wan2.1 \
--task t2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/wan_skyreels_v2_t2v.json \
--prompt "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." \
--negative_prompt 色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走 \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_skyreels_v2_t2v.mp4