Unverified Commit 9826b8ca authored by Watebear, committed by GitHub

[feat]: support matrix game2 universal, gta_drive, templerun & streaming mode

parent 44e215f3
import os
from safetensors import safe_open
from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
from lightx2v.models.networks.wan.weights.transformer_weights import (
WanFFN,
WanSelfAttention,
WanTransformerAttentionBlock,
)
from lightx2v.utils.registry_factory import (
ATTN_WEIGHT_REGISTER,
LN_WEIGHT_REGISTER,
MM_WEIGHT_REGISTER,
RMS_WEIGHT_REGISTER,
TENSOR_REGISTER,
)
class WanActionTransformerWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.blocks_num = config["num_layers"]
self.task = config["task"]
self.config = config
self.mm_type = config.get("dit_quant_scheme", "Default")
if self.mm_type != "Default":
assert config.get("dit_quantized") is True
action_blocks = config["action_config"]["blocks"]
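# only the block indices listed in action_config["blocks"] get the action-aware block; all other layers keep the standard Wan attention block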
block_list = []
for i in range(self.blocks_num):
if i in action_blocks:
block_list.append(WanTransformerActionBlock(i, self.task, self.mm_type, self.config, "blocks"))
else:
block_list.append(WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config, "blocks"))
self.blocks = WeightModuleList(block_list)
self.add_module("blocks", self.blocks)
# non blocks weights
self.register_parameter("norm", LN_WEIGHT_REGISTER["Default"]())
self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))
def clear(self):
for block in self.blocks:
for phase in block.compute_phases:
phase.clear()
def non_block_weights_to_cuda(self):
self.norm.to_cuda()
self.head.to_cuda()
self.head_modulation.to_cuda()
def non_block_weights_to_cpu(self):
self.norm.to_cpu()
self.head.to_cpu()
self.head_modulation.to_cpu()
class WanTransformerActionBlock(WeightModule):
def __init__(self, block_index, task, mm_type, config, block_prefix="blocks"):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
self.config = config
self.quant_method = config.get("quant_method", None)
self.lazy_load = self.config.get("lazy_load", False)
if self.lazy_load:
lazy_load_path = os.path.join(self.config["dit_quantized_ckpt"], f"block_{block_index}.safetensors")
self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
else:
self.lazy_load_file = None
self.compute_phases = WeightModuleList(
[
WanSelfAttention(
block_index,
block_prefix,
task,
mm_type,
config,
self.lazy_load,
self.lazy_load_file,
),
WanActionCrossAttention(
block_index,
block_prefix,
task,
mm_type,
config,
self.lazy_load,
self.lazy_load_file,
),
WanActionModule(
block_index,
block_prefix,
task,
mm_type,
config,
self.lazy_load,
self.lazy_load_file,
),
WanFFN(
block_index,
block_prefix,
task,
mm_type,
config,
self.lazy_load,
self.lazy_load_file,
),
]
)
self.add_module("compute_phases", self.compute_phases)
class WanActionModule(WeightModule):
def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
self.config = config
self.quant_method = config.get("quant_method", None)
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
self.attn_rms_type = "self_forcing"
self.add_module(
"keyboard_embed_0",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.0.weight",
f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.0.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"keyboard_embed_2",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.2.weight",
f"{block_prefix}.{self.block_index}.action_model.keyboard_embed.2.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"proj_keyboard",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.proj_keyboard.weight",
bias_name=None,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
),
)
self.add_module(
"keyboard_attn_kv",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.keyboard_attn_kv.weight",
bias_name=None,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
),
)
self.add_module("cross_attn_2", ATTN_WEIGHT_REGISTER[self.config["cross_attn_2_type"]]())
self.add_module(
"mouse_attn_q",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.mouse_attn_q.weight",
bias_name=None,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
),
)
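# templerun is keyboard-only, so the mouse branch (t_qkv, proj_mouse, mouse_mlp) is only registered for the other modes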
if self.config["mode"] != "templerun":
self.add_module(
"t_qkv",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.t_qkv.weight",
bias_name=None,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
),
)
self.add_module(
"proj_mouse",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.proj_mouse.weight",
bias_name=None,
lazy_load=self.lazy_load,
lazy_load_file=self.lazy_load_file,
),
)
self.add_module(
"mouse_mlp_0",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.0.weight",
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.0.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"mouse_mlp_2",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.2.weight",
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.2.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"mouse_mlp_3",
LN_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.3.weight",
f"{block_prefix}.{self.block_index}.action_model.mouse_mlp.3.bias",
self.lazy_load,
self.lazy_load_file,
eps=1e-6,
),
)
class WanActionCrossAttention(WeightModule):
def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
super().__init__()
self.block_index = block_index
self.mm_type = mm_type
self.task = task
self.config = config
self.lazy_load = lazy_load
self.lazy_load_file = lazy_load_file
if self.config.get("sf_config", False):
self.attn_rms_type = "self_forcing"
else:
self.attn_rms_type = "sgl-kernel"
self.add_module(
"norm3",
LN_WEIGHT_REGISTER["Default"](
f"{block_prefix}.{self.block_index}.norm3.weight",
f"{block_prefix}.{self.block_index}.norm3.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"cross_attn_q",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.q.weight",
f"{block_prefix}.{self.block_index}.cross_attn.q.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"cross_attn_k",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.k.weight",
f"{block_prefix}.{self.block_index}.cross_attn.k.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"cross_attn_v",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.v.weight",
f"{block_prefix}.{self.block_index}.cross_attn.v.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"cross_attn_o",
MM_WEIGHT_REGISTER[self.mm_type](
f"{block_prefix}.{self.block_index}.cross_attn.o.weight",
f"{block_prefix}.{self.block_index}.cross_attn.o.bias",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"cross_attn_norm_q",
RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.cross_attn.norm_q.weight",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module(
"cross_attn_norm_k",
RMS_WEIGHT_REGISTER[self.attn_rms_type](
f"{block_prefix}.{self.block_index}.cross_attn.norm_k.weight",
self.lazy_load,
self.lazy_load_file,
),
)
self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
@@ -303,7 +303,7 @@ class WanCrossAttention(WeightModule):
)
self.add_module("cross_attn_1", ATTN_WEIGHT_REGISTER[self.config["cross_attn_1_type"]]())
if self.config["task"] in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True):
if self.config["task"] in ["i2v", "flf2v", "animate", "s2v"] and self.config.get("use_image_encoder", True) and self.config["model_cls"] != "wan2.1_sf_mtxg2":
self.add_module(
"cross_attn_k_img",
MM_WEIGHT_REGISTER[self.mm_type](
......
import os
import torch
from diffusers.utils import load_image
from torchvision.transforms import v2
from loguru import logger
from lightx2v.models.input_encoders.hf.wan.matrix_game2.clip import CLIPModel
from lightx2v.models.input_encoders.hf.wan.matrix_game2.conditions import Bench_actions_gta_drive, Bench_actions_templerun, Bench_actions_universal
from lightx2v.models.networks.wan.matrix_game2_model import WanSFMtxg2Model
from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner
from lightx2v.models.video_encoders.hf.wan.vae_sf import WanMtxg2VAE
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.profiler import *
from lightx2v.utils.registry_factory import RUNNER_REGISTER
class VAEWrapper:
def __init__(self, vae):
self.vae = vae
def __getattr__(self, name):
if name in self.__dict__:
return self.__dict__[name]
else:
return getattr(self.vae, name)
def encode(self, x):
raise NotImplementedError
def decode(self, latents):
raise NotImplementedError
class WanxVAEWrapper(VAEWrapper):
def __init__(self, vae, clip):
self.vae = vae
self.vae.requires_grad_(False)
self.vae.eval()
self.clip = clip
if clip is not None:
self.clip.requires_grad_(False)
self.clip.eval()
def encode(self, x, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
x = self.vae.encode(x, device=device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride) # already scaled
return x # torch.stack(x, dim=0)
def clip_img(self, x):
x = self.clip(x)
return x
def decode(self, latents, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
videos = self.vae.decode(latents, device=device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
return videos # self.vae.decode(videos, dim=0) # already scaled
def to(self, device, dtype):
# move the VAE to the target device and dtype
self.vae = self.vae.to(device, dtype)
# if a CLIP model is present, move it as well
if self.clip is not None:
self.clip = self.clip.to(device, dtype)
return self
def get_wanx_vae_wrapper(model_path, weight_dtype):
vae = WanMtxg2VAE(pretrained_path=os.path.join(model_path, "Wan2.1_VAE.pth")).to(weight_dtype)
clip = CLIPModel(checkpoint_path=os.path.join(model_path, "models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth"), tokenizer_path=os.path.join(model_path, "xlm-roberta-large"))
return WanxVAEWrapper(vae, clip)
def get_current_action(mode="universal"):
CAM_VALUE = 0.1
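# returns the per-segment control dict; e.g. in universal mode entering `i` (camera up) and `w` (forward) yields roughly {"mouse": tensor([0.1, 0.0]), "keyboard": tensor([1, 0, 0, 0])} as CUDA tensors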
if mode == "universal":
print()
print("-" * 30)
print("PRESS [I, K, J, L, U] FOR CAMERA TRANSFORM\n (I: up, K: down, J: left, L: right, U: no move)")
print("PRESS [W, S, A, D, Q] FOR MOVEMENT\n (W: forward, S: back, A: left, D: right, Q: no move)")
print("-" * 30)
CAMERA_VALUE_MAP = {"i": [CAM_VALUE, 0], "k": [-CAM_VALUE, 0], "j": [0, -CAM_VALUE], "l": [0, CAM_VALUE], "u": [0, 0]}
KEYBOARD_IDX = {"w": [1, 0, 0, 0], "s": [0, 1, 0, 0], "a": [0, 0, 1, 0], "d": [0, 0, 0, 1], "q": [0, 0, 0, 0]}
flag = 0
while flag != 1:
try:
idx_mouse = input("Please input the mouse action (e.g. `U`):\n").strip().lower()
idx_keyboard = input("Please input the keyboard action (e.g. `W`):\n").strip().lower()
if idx_mouse in CAMERA_VALUE_MAP.keys() and idx_keyboard in KEYBOARD_IDX.keys():
flag = 1
except Exception:
pass
mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse]).cuda()
keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).cuda()
elif mode == "gta_drive":
print()
print("-" * 30)
print("PRESS [W, S, A, D, Q] FOR MOVEMENT\n (W: forward, S: back, A: left, D: right, Q: no move)")
print("-" * 30)
CAMERA_VALUE_MAP = {"a": [0, -CAM_VALUE], "d": [0, CAM_VALUE], "q": [0, 0]}
KEYBOARD_IDX = {"w": [1, 0], "s": [0, 1], "q": [0, 0]}
flag = 0
while flag != 1:
try:
indexes = input("Please input the actions (split with ` `):\n(e.g. `W` for forward, `W A` for forward and left)\n").strip().lower().split(" ")
idx_mouse = []
idx_keyboard = []
for i in indexes:
if i in CAMERA_VALUE_MAP.keys():
idx_mouse += [i]
elif i in KEYBOARD_IDX.keys():
idx_keyboard += [i]
if len(idx_mouse) == 0:
idx_mouse += ["q"]
if len(idx_keyboard) == 0:
idx_keyboard += ["q"]
assert idx_mouse in [["a"], ["d"], ["q"]] and idx_keyboard in [["q"], ["w"], ["s"]]
flag = 1
except Exception:
pass
mouse_cond = torch.tensor(CAMERA_VALUE_MAP[idx_mouse[0]]).cuda()
keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard[0]]).cuda()
elif mode == "templerun":
print()
print("-" * 30)
print("PRESS [W, S, A, D, Z, C, Q] FOR ACTIONS\n (W: jump, S: slide, A: left side, D: right side, Z: turn left, C: turn right, Q: no move)")
print("-" * 30)
KEYBOARD_IDX = {
"w": [0, 1, 0, 0, 0, 0, 0],
"s": [0, 0, 1, 0, 0, 0, 0],
"a": [0, 0, 0, 0, 0, 1, 0],
"d": [0, 0, 0, 0, 0, 0, 1],
"z": [0, 0, 0, 1, 0, 0, 0],
"c": [0, 0, 0, 0, 1, 0, 0],
"q": [1, 0, 0, 0, 0, 0, 0],
}
flag = 0
while flag != 1:
try:
idx_keyboard = input("Please input the action: \n(e.g. `W` for forward, `Z` for turning left)\n").strip().lower()
if idx_keyboard in KEYBOARD_IDX.keys():
flag = 1
except Exception:
pass
keyboard_cond = torch.tensor(KEYBOARD_IDX[idx_keyboard]).cuda()
if mode != "templerun":
return {"mouse": mouse_cond, "keyboard": keyboard_cond}
return {"keyboard": keyboard_cond}
@RUNNER_REGISTER("wan2.1_sf_mtxg2")
class WanSFMtxg2Runner(WanSFRunner):
def __init__(self, config):
super().__init__(config)
self.frame_process = v2.Compose(
[
v2.Resize(size=(352, 640), antialias=True),
v2.ToTensor(),
v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
)
self.device = torch.device("cuda")
self.weight_dtype = torch.bfloat16
def load_text_encoder(self):
from lightx2v.models.input_encoders.hf.wan.matrix_game2.conditions import MatrixGame2_Bench
return MatrixGame2_Bench()
def load_image_encoder(self):
wrapper = get_wanx_vae_wrapper(self.config["model_path"], torch.float16)
wrapper.requires_grad_(False)
wrapper.eval()
return wrapper.to(self.device, self.weight_dtype)
def _resizecrop(self, image, th, tw):
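# center-crop the PIL image to the target aspect ratio th:tw; the actual resize to 352x640 is done later by self.frame_process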
w, h = image.size
if h / w > th / tw:
new_w = int(w)
new_h = int(new_w * th / tw)
else:
new_h = int(h)
new_w = int(new_h * tw / th)
left = (w - new_w) / 2
top = (h - new_h) / 2
right = (w + new_w) / 2
bottom = (h + new_h) / 2
image = image.crop((left, top, right, bottom))
return image
@ProfilingContext4DebugL2("Run Encoders")
def _run_input_encoder_local_i2v(self):
# image
image = load_image(self.input_info.image_path)
image = self._resizecrop(image, 352, 640)
image = self.frame_process(image)[None, :, None, :, :].to(dtype=self.weight_dtype, device=self.device)
padding_video = torch.zeros_like(image).repeat(1, 1, 4 * (self.config["num_output_frames"] - 1), 1, 1)
img_cond = torch.concat([image, padding_video], dim=2)
tiler_kwargs = {"tiled": True, "tile_size": [44, 80], "tile_stride": [23, 38]}
img_cond = self.image_encoder.encode(img_cond, device=self.device, **tiler_kwargs).to(self.device)
mask_cond = torch.ones_like(img_cond)
mask_cond[:, :, 1:] = 0
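# only the first latent frame is conditioned on the reference image; all later frames are masked out and left for generation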
cond_concat = torch.cat([mask_cond[:, :4], img_cond], dim=1)
visual_context = self.image_encoder.clip.encode_video(image)
image_encoder_output = {"cond_concat": cond_concat.to(device=self.device, dtype=self.weight_dtype), "visual_context": visual_context.to(device=self.device, dtype=self.weight_dtype)}
# text
text_encoder_output = {}
num_frames = (self.config["num_output_frames"] - 1) * 4 + 1
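# the causal VAE maps 1 + 4 * (n - 1) pixel frames to n latent frames, so this recovers the pixel-frame count from num_output_frames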
if self.config["mode"] == "universal":
cond_data = Bench_actions_universal(num_frames)
mouse_condition = cond_data["mouse_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
text_encoder_output["mouse_cond"] = mouse_condition
elif self.config["mode"] == "gta_drive":
cond_data = Bench_actions_gta_drive(num_frames)
mouse_condition = cond_data["mouse_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
text_encoder_output["mouse_cond"] = mouse_condition
else:
cond_data = Bench_actions_templerun(num_frames)
keyboard_condition = cond_data["keyboard_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
text_encoder_output["keyboard_cond"] = keyboard_condition
# set shape
self.input_info.latent_shape = [16, self.config["num_output_frames"], 44, 80]
return {"text_encoder_output": text_encoder_output, "image_encoder_output": image_encoder_output}
def load_transformer(self):
model = WanSFMtxg2Model(
self.config["model_path"],
self.config,
self.init_device,
)
return model
def init_run_segment(self, segment_idx):
self.segment_idx = segment_idx
if self.config["streaming"]:
self.inputs["current_actions"] = get_current_action(mode=self.config["mode"])
@ProfilingContext4DebugL2("Run DiT")
def run_main(self, total_steps=None):
self.init_run()
if self.config.get("compile", False):
self.model.select_graph_for_compile(self.input_info)
stop = ""
while stop != "n":
for segment_idx in range(self.video_segment_num):
logger.info(f"🔄 start segment {segment_idx + 1}/{self.video_segment_num}")
with ProfilingContext4DebugL1(
f"segment end2end {segment_idx + 1}/{self.video_segment_num}",
recorder_mode=GET_RECORDER_MODE(),
metrics_func=monitor_cli.lightx2v_run_segments_end2end_duration,
metrics_labels=["DefaultRunner"],
):
self.check_stop()
# 1. default do nothing
self.init_run_segment(segment_idx)
# 2. main inference loop
latents = self.run_segment(total_steps=total_steps)
# 3. vae decoder
self.gen_video = self.run_vae_decoder(latents)
# 4. default do nothing
self.end_run_segment(segment_idx)
# 5. stop or not
if self.config["streaming"]:
stop = input("Press `n` to stop generation: ").strip().lower()
if stop == "n":
break
stop = "n"
gen_video_final = self.process_images_after_vae_decoder()
self.end_run()
return gen_video_final
@@ -720,6 +720,23 @@ class WanVAE_(nn.Module):
self.clear_cache()
return out
def cached_decode(self, z, scale):
# z: [b,c,t,h,w]
if isinstance(scale[0], torch.Tensor):
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(1, self.z_dim, 1, 1, 1)
else:
z = z / scale[1] + scale[0]
iter_ = z.shape[2]
x = self.conv2(z)
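# decode one latent frame at a time, reusing the decoder's causal feature cache so temporal context carries over between frames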
for i in range(iter_):
self._conv_idx = [0]
if i == 0:
out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
else:
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
out = torch.cat([out, out_], 2)
return out
def reparameterize(self, mu, log_var):
std = torch.exp(0.5 * log_var)
eps = torch.randn_like(std)
......
import torch
import torch.nn as nn
from einops import rearrange, repeat
from lightx2v.models.video_encoders.hf.wan.vae import _video_vae
from lightx2v.models.video_encoders.hf.wan.vae import WanVAE_, _video_vae
class WanSFVAE:
@@ -15,6 +17,7 @@ class WanSFVAE:
cpu_offload=False,
use_2d_split=True,
load_from_rank0=False,
**kwargs,
):
self.dtype = dtype
self.device = device
@@ -27,10 +30,12 @@ class WanSFVAE:
std = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160]
self.mean = torch.tensor(mean, dtype=torch.float32)
self.std = torch.tensor(std, dtype=torch.float32)
self.scale = [self.mean, 1.0 / self.std]
# init model
self.model = _video_vae(pretrained_path=vae_path, z_dim=z_dim, cpu_offload=cpu_offload, dtype=dtype, load_from_rank0=load_from_rank0).eval().requires_grad_(False).to(device).to(dtype)
self.model.clear_cache()
self.upsampling_factor = 8
def to_cpu(self):
self.model.encoder = self.model.encoder.to("cpu")
@@ -72,3 +77,269 @@ class WanSFVAE:
# to [batch_size, num_frames, num_channels, height, width]
output = output.permute(0, 2, 1, 3, 4).squeeze(0)
return output
def tiled_encode(self, video, device, tile_size, tile_stride):
_, _, T, H, W = video.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for h in range(0, H, stride_h):
if h - stride_h >= 0 and h - stride_h + size_h >= H:
continue
for w in range(0, W, stride_w):
if w - stride_w >= 0 and w - stride_w + size_w >= W:
continue
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))
data_device = "cpu"
computation_device = device
out_T = (T + 3) // 4
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
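# encode each spatial tile separately and blend overlapping regions using the feathered masks accumulated in `weight`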
for h, h_, w, w_ in tasks: # tqdm(tasks, desc="VAE encoding"):
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
mask = self.build_mask(
hidden_states_batch, is_bound=(h == 0, h_ >= H, w == 0, w_ >= W), border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
).to(dtype=video.dtype, device=data_device)
target_h = h // self.upsampling_factor
target_w = w // self.upsampling_factor
values[
:,
:,
:,
target_h : target_h + hidden_states_batch.shape[3],
target_w : target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
:,
target_h : target_h + hidden_states_batch.shape[3],
target_w : target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
return values
def single_encode(self, video, device):
video = video.to(device)
x = self.model.encode(video, self.scale)
return x
def encode(self, videos, device, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
videos = [video.to("cpu") for video in videos]
hidden_states = []
if tiled:
# scale latent-space tile sizes to pixel space once, before the per-video loop
tile_size = (tile_size[0] * 8, tile_size[1] * 8)
tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
for video in videos:
video = video.unsqueeze(0)
if tiled:
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
else:
hidden_state = self.single_encode(video, device)
hidden_state = hidden_state.squeeze(0)
hidden_states.append(hidden_state)
hidden_states = torch.stack(hidden_states)
return hidden_states
class WanMtxg2VAE(nn.Module):
def __init__(self, pretrained_path=None, z_dim=16):
super().__init__()
mean = [-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508, 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921]
std = [2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743, 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160]
self.mean = torch.tensor(mean)
self.std = torch.tensor(std)
self.scale = [self.mean, 1.0 / self.std]
# init model
self.model = (
WanVAE_(
dim=96,
z_dim=z_dim,
num_res_blocks=2,
dim_mult=[1, 2, 4, 4],
temperal_downsample=[False, True, True],
dropout=0.0,
pruning_rate=0.0,
)
.eval()
.requires_grad_(False)
)
if pretrained_path is not None:
self.model.load_state_dict(torch.load(pretrained_path, map_location="cpu"), assign=True)
self.upsampling_factor = 8
def to(self, *args, **kwargs):
self.mean = self.mean.to(*args, **kwargs)
self.std = self.std.to(*args, **kwargs)
self.scale = [self.mean, 1.0 / self.std]
self.model = self.model.to(*args, **kwargs)
return self
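# build_1d_mask / build_mask create linear ramps at tile borders so overlapping tiles are feathered together instead of leaving visible seams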
def build_1d_mask(self, length, left_bound, right_bound, border_width):
x = torch.ones((length,))
if not left_bound:
x[:border_width] = (torch.arange(border_width) + 1) / border_width
if not right_bound:
x[-border_width:] = torch.flip((torch.arange(border_width) + 1) / border_width, dims=(0,))
return x
def build_mask(self, data, is_bound, border_width):
_, _, _, H, W = data.shape
h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])
h = repeat(h, "H -> H W", H=H, W=W)
w = repeat(w, "W -> H W", H=H, W=W)
mask = torch.stack([h, w]).min(dim=0).values
mask = rearrange(mask, "H W -> 1 1 1 H W")
return mask
def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
_, _, T, H, W = hidden_states.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for h in range(0, H, stride_h):
if h - stride_h >= 0 and h - stride_h + size_h >= H:
continue
for w in range(0, W, stride_w):
if w - stride_w >= 0 and w - stride_w + size_w >= W:
continue
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))
data_device = "cpu" # TODO
computation_device = device
out_T = T * 4 - 3
weight = torch.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
values = torch.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
for h, h_, w, w_ in tasks: # tqdm(tasks, desc="VAE decoding"):
hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
mask = self.build_mask(
hidden_states_batch, is_bound=(h == 0, h_ >= H, w == 0, w_ >= W), border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
).to(dtype=hidden_states.dtype, device=data_device)
target_h = h * self.upsampling_factor
target_w = w * self.upsampling_factor
values[
:,
:,
:,
target_h : target_h + hidden_states_batch.shape[3],
target_w : target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
:,
target_h : target_h + hidden_states_batch.shape[3],
target_w : target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
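# dividing by the accumulated mask weights averages overlapping tiles; the clamp below keeps outputs in the VAE's [-1, 1] range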
values = values.clamp_(-1, 1)
return values
def tiled_encode(self, video, device, tile_size, tile_stride):
_, _, T, H, W = video.shape
size_h, size_w = tile_size
stride_h, stride_w = tile_stride
# Split tasks
tasks = []
for h in range(0, H, stride_h):
if h - stride_h >= 0 and h - stride_h + size_h >= H:
continue
for w in range(0, W, stride_w):
if w - stride_w >= 0 and w - stride_w + size_w >= W:
continue
h_, w_ = h + size_h, w + size_w
tasks.append((h, h_, w, w_))
data_device = "cpu"
computation_device = device
out_T = (T + 3) // 4
weight = torch.zeros((1, 1, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
values = torch.zeros((1, 16, out_T, H // self.upsampling_factor, W // self.upsampling_factor), dtype=video.dtype, device=data_device)
for h, h_, w, w_ in tasks: # tqdm(tasks, desc="VAE encoding"):
hidden_states_batch = video[:, :, :, h:h_, w:w_].to(computation_device)
hidden_states_batch = self.model.encode(hidden_states_batch, self.scale).to(data_device)
mask = self.build_mask(
hidden_states_batch, is_bound=(h == 0, h_ >= H, w == 0, w_ >= W), border_width=((size_h - stride_h) // self.upsampling_factor, (size_w - stride_w) // self.upsampling_factor)
).to(dtype=video.dtype, device=data_device)
target_h = h // self.upsampling_factor
target_w = w // self.upsampling_factor
values[
:,
:,
:,
target_h : target_h + hidden_states_batch.shape[3],
target_w : target_w + hidden_states_batch.shape[4],
] += hidden_states_batch * mask
weight[
:,
:,
:,
target_h : target_h + hidden_states_batch.shape[3],
target_w : target_w + hidden_states_batch.shape[4],
] += mask
values = values / weight
return values
def single_encode(self, video, device):
video = video.to(device)
x = self.model.encode(video, self.scale)
return x
def single_decode(self, hidden_state, device):
hidden_state = hidden_state.to(device)
video = self.model.decode(hidden_state, self.scale)
return video.clamp_(-1, 1)
def encode(self, videos, device, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
# videos: batch of pixel-space clips, e.g. [1, 3, 597, 352, 640]
videos = [video.to("cpu") for video in videos]
hidden_states = []
if tiled:
# scale latent-space tile sizes to pixel space once, before the per-video loop
tile_size = (tile_size[0] * 8, tile_size[1] * 8)
tile_stride = (tile_stride[0] * 8, tile_stride[1] * 8)
for video in videos:
video = video.unsqueeze(0)
if tiled:
hidden_state = self.tiled_encode(video, device, tile_size, tile_stride)
else:
hidden_state = self.single_encode(video, device)
hidden_state = hidden_state.squeeze(0)
hidden_states.append(hidden_state)
hidden_states = torch.stack(hidden_states)
return hidden_states
def decode(self, hidden_states, device, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
videos = []
for hidden_state in hidden_states:
hidden_state = hidden_state.unsqueeze(0)
if tiled:
video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
else:
video = self.single_decode(hidden_state, device)
video = video.squeeze(0)
videos.append(video)
videos = torch.stack(videos)
return videos
#!/bin/bash
# set paths first
lightx2v_path=/path/to/LightX2V
model_path=/path/to/Skywork/Matrix-Game-2.0
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
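# offline (non-streaming) GTA-Drive generation; use the *_streaming config for interactive control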
python -m lightx2v.infer \
--model_cls wan2.1_sf_mtxg2 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_gta_drive.json \
--prompt '' \
--image_path gta_drive/0003.png \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_matrix_game2_gta_drive.mp4 \
--seed 42
#!/bin/bash
# set paths first
lightx2v_path=/path/to/LightX2V
model_path=/path/to/Skywork/Matrix-Game-2.0
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
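# streaming GTA-Drive generation: actions are typed in for every segment and entering `n` at the stop prompt ends generation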
python -m lightx2v.infer \
--model_cls wan2.1_sf_mtxg2 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_gta_drive_streaming.json \
--prompt '' \
--image_path gta_drive/0003.png \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_matrix_game2_gta_drive_streaming.mp4 \
--seed 42
#!/bin/bash
# set paths first
lightx2v_path=/path/to/LightX2V
model_path=/path/to/Skywork/Matrix-Game-2.0
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls wan2.1_sf_mtxg2 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_templerun.json \
--prompt '' \
--image_path templerun/0005.png \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_matrix_game2_templerun.mp4 \
--seed 42
#!/bin/bash
# set paths first
lightx2v_path=/path/to/LightX2V
model_path=/path/to/Skywork/Matrix-Game-2.0
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls wan2.1_sf_mtxg2 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_templerun_streaming.json \
--prompt '' \
--image_path templerun/0005.png \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_matrix_game2_templerun_streaming.mp4 \
--seed 42
#!/bin/bash
# set paths first
lightx2v_path=/path/to/LightX2V
model_path=/path/to/Skywork/Matrix-Game-2.0
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls wan2.1_sf_mtxg2 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_universal.json \
--prompt '' \
--image_path universal/0007.png \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_matrix_game2_universal.mp4 \
--seed 42
#!/bin/bash
# set paths first
lightx2v_path=/path/to/LightX2V
model_path=/path/to/Skywork/Matrix-Game-2.0
export CUDA_VISIBLE_DEVICES=0
# set environment variables
source ${lightx2v_path}/scripts/base/base.sh
python -m lightx2v.infer \
--model_cls wan2.1_sf_mtxg2 \
--task i2v \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/matrix_game2/matrix_game2_universal_streaming.json \
--prompt '' \
--image_path universal/0007.png \
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_matrix_game2_universal_streaming.mp4 \
--seed 42
@@ -17,4 +17,4 @@ python -m lightx2v.infer \
--sf_model_path $sf_model_path \
--config_json ${lightx2v_path}/configs/self_forcing/wan_t2v_sf.json \
--prompt 'A stylish woman strolls down a bustling Tokyo street, the warm glow of neon lights and animated city signs casting vibrant reflections. She wears a sleek black leather jacket paired with a flowing red dress and black boots, her black purse slung over her shoulder. Sunglasses perched on her nose and a bold red lipstick add to her confident, casual demeanor. The street is damp and reflective, creating a mirror-like effect that enhances the colorful lights and shadows. Pedestrians move about, adding to the lively atmosphere. The scene is captured in a dynamic medium shot with the woman walking slightly to one side, highlighting her graceful strides.' \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_sf.mp4
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_sf.mp4
#!/bin/bash
export CUDA_VISIBLE_DEVICES=
# set paths first
export lightx2v_path=
export model_path=
# check section
if [ -z "${CUDA_VISIBLE_DEVICES}" ]; then
cuda_devices=0
echo "Warn: CUDA_VISIBLE_DEVICES is not set, using default value: ${cuda_devices}, change at shell script or set env variable."
export CUDA_VISIBLE_DEVICES=${cuda_devices}
fi
if [ -z "${lightx2v_path}" ]; then
echo "Error: lightx2v_path is not set. Please set this variable first."
exit 1
fi
if [ -z "${model_path}" ]; then
echo "Error: model_path is not set. Please set this variable first."
exit 1
fi
export TOKENIZERS_PARALLELISM=false
export PYTHONPATH=${lightx2v_path}:$PYTHONPATH
export DTYPE=BF16
export PROFILING_DEBUG_LEVEL=2
python -m lightx2v.infer \
--model_cls qwen_image \
--task i2i \
--model_path $model_path \
--config_json ${lightx2v_path}/configs/qwen_image/qwen_image_i2i_2509.json \
--prompt "Have the two characters swap clothes and stand in front of the castle." \
--negative_prompt " " \
--image_path 1.jpeg,2.jpeg \
--save_result_path ${lightx2v_path}/save_results/qwen_image_i2i_2509.png \
--seed 0
@@ -17,4 +17,4 @@ python -m lightx2v.infer \
--sf_model_path $sf_model_path \
--config_json ${lightx2v_path}/configs/self_forcing/wan_t2v_sf.json \
--prompt 'A stylish woman strolls down a bustling Tokyo street, the warm glow of neon lights and animated city signs casting vibrant reflections. She wears a sleek black leather jacket paired with a flowing red dress and black boots, her black purse slung over her shoulder. Sunglasses perched on her nose and a bold red lipstick add to her confident, casual demeanor. The street is damp and reflective, creating a mirror-like effect that enhances the colorful lights and shadows. Pedestrians move about, adding to the lively atmosphere. The scene is captured in a dynamic medium shot with the woman walking slightly to one side, highlighting her graceful strides.' \
--save_video_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_sf.mp4
--save_result_path ${lightx2v_path}/save_results/output_lightx2v_wan_t2v_sf.mp4