Unverified Commit 9826b8ca authored by Watebear, committed by GitHub

[feat]: support matrix game2 universal, gta_drive, templerun & streaming mode

parent 44e215f3
{
"infer_steps": 50,
"target_video_length": 150,
"num_output_frames": 150,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
"seed": 0,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"cpu_offload": false,
"sf_config": {
"local_attn_size": 6,
"shift": 5.0,
"num_frame_per_block": 3,
"num_transformer_blocks": 30,
"frame_seq_length": 880,
"num_output_frames": 150,
"num_inference_steps": 1000,
"denoising_step_list": [1000.0000, 908.8427, 713.9794]
},
"sub_model_folder": "gta_distilled_model",
"sub_model_name": "gta_keyboard2dim.safetensors",
"mode": "gta_drive",
"streaming": false,
"action_config": {
"blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"enable_keyboard": true,
"enable_mouse": true,
"heads_num": 16,
"hidden_size": 128,
"img_hidden_size": 1536,
"keyboard_dim_in": 4,
"keyboard_hidden_dim": 1024,
"mouse_dim_in": 2,
"mouse_hidden_dim": 1024,
"mouse_qk_dim_list": [
8,
28,
28
],
"patch_size": [
1,
2,
2
],
"qk_norm": true,
"qkv_bias": false,
"rope_dim_list": [
8,
28,
28
],
"rope_theta": 256,
"vae_time_compression_ratio": 4,
"windows_size": 3
},
"dim": 1536,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_dim": 36,
"inject_sample_info": false,
"model_type": "i2v",
"num_heads": 12,
"num_layers": 30,
"out_dim": 16
}
{
"infer_steps": 50,
"target_video_length": 360,
"num_output_frames": 360,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
"seed": 0,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"cpu_offload": false,
"sf_config": {
"local_attn_size": 6,
"shift": 5.0,
"num_frame_per_block": 3,
"num_transformer_blocks": 30,
"frame_seq_length": 880,
"num_output_frames": 360,
"num_inference_steps": 1000,
"denoising_step_list": [1000.0000, 908.8427, 713.9794]
},
"sub_model_folder": "gta_distilled_model",
"sub_model_name": "gta_keyboard2dim.safetensors",
"mode": "gta_drive",
"streaming": true,
"action_config": {
"blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"enable_keyboard": true,
"enable_mouse": true,
"heads_num": 16,
"hidden_size": 128,
"img_hidden_size": 1536,
"keyboard_dim_in": 4,
"keyboard_hidden_dim": 1024,
"mouse_dim_in": 2,
"mouse_hidden_dim": 1024,
"mouse_qk_dim_list": [
8,
28,
28
],
"patch_size": [
1,
2,
2
],
"qk_norm": true,
"qkv_bias": false,
"rope_dim_list": [
8,
28,
28
],
"rope_theta": 256,
"vae_time_compression_ratio": 4,
"windows_size": 3
},
"dim": 1536,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_dim": 36,
"inject_sample_info": false,
"model_type": "i2v",
"num_heads": 12,
"num_layers": 30,
"out_dim": 16
}
{
"infer_steps": 50,
"target_video_length": 150,
"num_output_frames": 150,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
"seed": 0,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"cpu_offload": false,
"sf_config": {
"local_attn_size": 6,
"shift": 5.0,
"num_frame_per_block": 3,
"num_transformer_blocks": 30,
"frame_seq_length": 880,
"num_output_frames": 150,
"num_inference_steps": 1000,
"denoising_step_list": [1000.0000, 908.8427, 713.9794]
},
"sub_model_folder": "templerun_distilled_model",
"sub_model_name": "templerun_7dim_onlykey.safetensors",
"mode": "templerun",
"streaming": false,
"action_config": {
"blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"enable_keyboard": true,
"enable_mouse": false,
"heads_num": 16,
"hidden_size": 128,
"img_hidden_size": 1536,
"keyboard_dim_in": 7,
"keyboard_hidden_dim": 1024,
"patch_size": [
1,
2,
2
],
"qk_norm": true,
"qkv_bias": false,
"rope_dim_list": [
8,
28,
28
],
"rope_theta": 256,
"vae_time_compression_ratio": 4,
"windows_size": 3
},
"dim": 1536,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_dim": 36,
"inject_sample_info": false,
"model_type": "i2v",
"num_heads": 12,
"num_layers": 30,
"out_dim": 16
}
{
"infer_steps": 50,
"target_video_length": 360,
"num_output_frames": 360,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
"seed": 0,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"cpu_offload": false,
"sf_config": {
"local_attn_size": 6,
"shift": 5.0,
"num_frame_per_block": 3,
"num_transformer_blocks": 30,
"frame_seq_length": 880,
"num_output_frames": 360,
"num_inference_steps": 1000,
"denoising_step_list": [1000.0000, 908.8427, 713.9794]
},
"sub_model_folder": "templerun_distilled_model",
"sub_model_name": "templerun_7dim_onlykey.safetensors",
"mode": "templerun",
"streaming": true,
"action_config": {
"blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"enable_keyboard": true,
"enable_mouse": false,
"heads_num": 16,
"hidden_size": 128,
"img_hidden_size": 1536,
"keyboard_dim_in": 7,
"keyboard_hidden_dim": 1024,
"patch_size": [
1,
2,
2
],
"qk_norm": true,
"qkv_bias": false,
"rope_dim_list": [
8,
28,
28
],
"rope_theta": 256,
"vae_time_compression_ratio": 4,
"windows_size": 3
},
"dim": 1536,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_dim": 36,
"inject_sample_info": false,
"model_type": "i2v",
"num_heads": 12,
"num_layers": 30,
"out_dim": 16
}
{
"infer_steps": 50,
"target_video_length": 150,
"num_output_frames": 150,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
"seed": 0,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"cpu_offload": false,
"sf_config": {
"local_attn_size": 6,
"shift": 5.0,
"num_frame_per_block": 3,
"num_transformer_blocks": 30,
"frame_seq_length": 880,
"num_output_frames": 150,
"num_inference_steps": 1000,
"denoising_step_list": [1000.0000, 908.8427, 713.9794]
},
"sub_model_folder": "base_distilled_model",
"sub_model_name": "base_distill.safetensors",
"mode": "universal",
"streaming": false,
"action_config": {
"blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"enable_keyboard": true,
"enable_mouse": true,
"heads_num": 16,
"hidden_size": 128,
"img_hidden_size": 1536,
"keyboard_dim_in": 4,
"keyboard_hidden_dim": 1024,
"mouse_dim_in": 2,
"mouse_hidden_dim": 1024,
"mouse_qk_dim_list": [
8,
28,
28
],
"patch_size": [
1,
2,
2
],
"qk_norm": true,
"qkv_bias": false,
"rope_dim_list": [
8,
28,
28
],
"rope_theta": 256,
"vae_time_compression_ratio": 4,
"windows_size": 3
},
"dim": 1536,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_dim": 36,
"inject_sample_info": false,
"model_type": "i2v",
"num_heads": 12,
"num_layers": 30,
"out_dim": 16
}
{
"infer_steps": 50,
"target_video_length": 360,
"num_output_frames": 360,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
"seed": 0,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"cpu_offload": false,
"sf_config": {
"local_attn_size": 6,
"shift": 5.0,
"num_frame_per_block": 3,
"num_transformer_blocks": 30,
"frame_seq_length": 880,
"num_output_frames": 360,
"num_inference_steps": 1000,
"denoising_step_list": [1000.0000, 908.8427, 713.9794]
},
"sub_model_folder": "base_distilled_model",
"sub_model_name": "base_distill.safetensors",
"mode": "universal",
"streaming": true,
"action_config": {
"blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"enable_keyboard": true,
"enable_mouse": true,
"heads_num": 16,
"hidden_size": 128,
"img_hidden_size": 1536,
"keyboard_dim_in": 4,
"keyboard_hidden_dim": 1024,
"mouse_dim_in": 2,
"mouse_hidden_dim": 1024,
"mouse_qk_dim_list": [
8,
28,
28
],
"patch_size": [
1,
2,
2
],
"qk_norm": true,
"qkv_bias": false,
"rope_dim_list": [
8,
28,
28
],
"rope_theta": 256,
"vae_time_compression_ratio": 4,
"windows_size": 3
},
"dim": 1536,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_dim": 36,
"inject_sample_info": false,
"model_type": "i2v",
"num_heads": 12,
"num_layers": 30,
"out_dim": 16
}
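For reference, a minimal sketch of how one of these mode configs might be loaded and dispatched on; the path and the load_mode_config helper are illustrative and not part of this commit:

import json

def load_mode_config(path):
    # Load one of the JSON configs above and sanity-check the frame counts,
    # which are expected to agree between the top level and sf_config.
    with open(path) as f:
        cfg = json.load(f)
    assert cfg["num_output_frames"] == cfg["sf_config"]["num_output_frames"]
    return cfg

cfg = load_mode_config("configs/matrix_game2/gta_drive_streaming.json")  # hypothetical path
print(cfg["mode"], cfg["streaming"], cfg["sub_model_folder"], cfg["sub_model_name"])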
@@ -51,7 +51,7 @@ except ImportError:
 try:
     from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
-except ModuleNotFoundError:
+except ImportError:
     quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
 try:
@@ -61,7 +61,7 @@ except ImportError:
 try:
     import marlin_cuda_quant
-except ModuleNotFoundError:
+except ImportError:
     marlin_cuda_quant = None
 try:
@@ -9,6 +9,7 @@ from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner
 from lightx2v.models.runners.wan.wan_animate_runner import WanAnimateRunner # noqa: F401
 from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner # noqa: F401
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner # noqa: F401
+from lightx2v.models.runners.wan.wan_matrix_game2_runner import WanSFMtxg2Runner # noqa: F401
 from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner # noqa: F401
 from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner # noqa: F401
 from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner # noqa: F401
@@ -39,6 +40,7 @@ def main():
         "wan2.1_distill",
         "wan2.1_vace",
         "wan2.1_sf",
+        "wan2.1_sf_mtxg2",
         "seko_talk",
         "wan2.2_moe",
         "wan2.2",
@@ -3,7 +3,7 @@ import torch.nn as nn
 try:
     from vllm import _custom_ops as ops
-except ModuleNotFoundError:
+except ImportError:
     ops = None
 try:
@@ -13,7 +13,7 @@ except ImportError:
 try:
     from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
-except ModuleNotFoundError:
+except ImportError:
     quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
 try:
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from diffusers.models import ModelMixin
from lightx2v.models.input_encoders.hf.wan.matrix_game2.tokenizers import HuggingfaceTokenizer
from lightx2v.models.input_encoders.hf.wan.xlm_roberta.model import VisionTransformer
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
# compute attention
p = self.dropout.p if self.training else 0.0
x = F.scaled_dot_product_attention(q, k, v, mask, p)
x = x.permute(0, 2, 1, 3).reshape(b, s, c)
# output
x = self.o(x)
x = self.dropout(x)
return x
class AttentionBlock(nn.Module):
def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.post_norm = post_norm
self.eps = eps
# layers
self.attn = SelfAttention(dim, num_heads, dropout, eps)
self.norm1 = nn.LayerNorm(dim, eps=eps)
self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout))
self.norm2 = nn.LayerNorm(dim, eps=eps)
def forward(self, x, mask):
if self.post_norm:
x = self.norm1(x + self.attn(x, mask))
x = self.norm2(x + self.ffn(x))
else:
x = x + self.attn(self.norm1(x), mask)
x = x + self.ffn(self.norm2(x))
return x
class XLMRoberta(nn.Module):
"""
XLMRobertaModel with no pooler and no LM head.
"""
def __init__(self, vocab_size=250002, max_seq_len=514, type_size=1, pad_id=1, dim=1024, num_heads=16, num_layers=24, post_norm=True, dropout=0.1, eps=1e-5):
super().__init__()
self.vocab_size = vocab_size
self.max_seq_len = max_seq_len
self.type_size = type_size
self.pad_id = pad_id
self.dim = dim
self.num_heads = num_heads
self.num_layers = num_layers
self.post_norm = post_norm
self.eps = eps
# embeddings
self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
self.type_embedding = nn.Embedding(type_size, dim)
self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
self.dropout = nn.Dropout(dropout)
# blocks
self.blocks = nn.ModuleList([AttentionBlock(dim, num_heads, post_norm, dropout, eps) for _ in range(num_layers)])
# norm layer
self.norm = nn.LayerNorm(dim, eps=eps)
def forward(self, ids):
"""
ids: [B, L] of torch.LongTensor.
"""
b, s = ids.shape
mask = ids.ne(self.pad_id).long()
# embeddings
x = self.token_embedding(ids) + self.type_embedding(torch.zeros_like(ids)) + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
if self.post_norm:
x = self.norm(x)
x = self.dropout(x)
# blocks
mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
for block in self.blocks:
x = block(x, mask)
# output
if not self.post_norm:
x = self.norm(x)
return x
class XLMRobertaWithHead(XLMRoberta):
def __init__(self, **kwargs):
self.out_dim = kwargs.pop("out_dim")
super().__init__(**kwargs)
# head
mid_dim = (self.dim + self.out_dim) // 2
self.head = nn.Sequential(nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False))
def forward(self, ids):
# xlm-roberta
x = super().forward(ids)
# average pooling
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
# head
x = self.head(x)
return x
class XLMRobertaCLIP(nn.Module):
def __init__(
self,
dtype=torch.float16,
embed_dim=1024,
image_size=224,
patch_size=14,
vision_dim=1280,
vision_mlp_ratio=4,
vision_heads=16,
vision_layers=32,
vision_pool="token",
vision_pre_norm=True,
vision_post_norm=False,
activation="gelu",
vocab_size=250002,
max_text_len=514,
type_size=1,
pad_id=1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5,
quantized=False,
quant_scheme=None,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
):
super().__init__()
self.embed_dim = embed_dim
self.image_size = image_size
self.patch_size = patch_size
self.vision_dim = vision_dim
self.vision_mlp_ratio = vision_mlp_ratio
self.vision_heads = vision_heads
self.vision_layers = vision_layers
self.vision_pre_norm = vision_pre_norm
self.vision_post_norm = vision_post_norm
self.activation = activation
self.vocab_size = vocab_size
self.max_text_len = max_text_len
self.type_size = type_size
self.pad_id = pad_id
self.norm_eps = norm_eps
# models
self.visual = VisionTransformer(
dtype=dtype,
image_size=image_size,
patch_size=patch_size,
dim=vision_dim,
mlp_ratio=vision_mlp_ratio,
out_dim=embed_dim,
num_heads=vision_heads,
num_layers=vision_layers,
pool_type=vision_pool,
pre_norm=vision_pre_norm,
post_norm=vision_post_norm,
activation=activation,
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
embedding_dropout=embedding_dropout,
norm_eps=norm_eps,
quantized=quantized,
quant_scheme=quant_scheme,
)
self.textual = XLMRobertaWithHead(
vocab_size=vocab_size,
max_seq_len=max_text_len,
type_size=type_size,
pad_id=pad_id,
dim=text_dim,
out_dim=embed_dim,
num_heads=text_heads,
num_layers=text_layers,
post_norm=text_post_norm,
dropout=text_dropout,
)
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
def _clip(pretrained=False, pretrained_name=None, model_cls=XLMRobertaCLIP, return_transforms=False, return_tokenizer=False, tokenizer_padding="eos", dtype=torch.float32, device="cpu", **kwargs):
# init a model on device
with torch.device(device):
model = model_cls(**kwargs)
# set device
model = model.to(dtype=dtype, device=device)
output = (model,)
# init transforms
if return_transforms:
# mean and std
if "siglip" in pretrained_name.lower():
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
else:
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
# transforms
transforms = T.Compose([T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=mean, std=std)])
output += (transforms,)
return output[0] if len(output) == 1 else output
def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs):
cfg = dict(
embed_dim=1024,
image_size=224,
patch_size=14,
vision_dim=1280,
vision_mlp_ratio=4,
vision_heads=16,
vision_layers=32,
vision_pool="token",
activation="gelu",
vocab_size=250002,
max_text_len=514,
type_size=1,
pad_id=1,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
)
cfg.update(**kwargs)
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
class CLIPModel(ModelMixin):
def __init__(self, checkpoint_path, tokenizer_path):
super().__init__()
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
pretrained=False,
return_transforms=True,
return_tokenizer=False,
)
self.model = self.model.eval().requires_grad_(False)
logging.info(f"loading {checkpoint_path}")
self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace")
def encode_video(self, video):
# preprocess
b, c, t, h, w = video.shape
video = video.transpose(1, 2)
video = video.reshape(b * t, c, h, w)
size = (self.model.image_size,) * 2
video = F.interpolate(video, size=size, mode="bicubic", align_corners=False)
video = self.transforms.transforms[-1](video.mul_(0.5).add_(0.5))
# forward
with torch.amp.autocast(dtype=self.dtype, device_type=self.device.type):
out = self.model.visual(video, use_31_block=True)
return out
def forward(self, videos):
# preprocess
size = (self.model.image_size,) * 2
videos = torch.cat([F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos])
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward
with torch.amp.autocast("cuda", dtype=self.dtype):
out = self.model.visual(videos, use_31_block=True)
return out
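# Rough usage sketch for the CLIP wrapper above (illustrative; the checkpoint and tokenizer
# paths are placeholders, not files added by this commit).
def _demo_clip_encode():
    clip = CLIPModel(
        checkpoint_path="open-clip-xlm-roberta-large-vit-huge-14.pth",  # placeholder path
        tokenizer_path="xlm-roberta-large",  # placeholder tokenizer location
    ).to("cuda", dtype=torch.bfloat16)
    # encode_video expects [B, C, T, H, W] pixel values in [-1, 1]; here a single 1-frame clip
    video = torch.zeros(1, 3, 1, 480, 832, device="cuda", dtype=torch.bfloat16)
    visual_context = clip.encode_video(video)  # intermediate ViT features (use_31_block=True)
    return visual_context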
import random
import torch
def combine_data(data, num_frames=57, keyboard_dim=6, mouse=True):
assert num_frames % 4 == 1
keyboard_condition = torch.zeros((num_frames, keyboard_dim))
if mouse:
mouse_condition = torch.zeros((num_frames, 2))
current_frame = 0
selections = [12]
while current_frame < num_frames:
rd_frame = selections[random.randint(0, len(selections) - 1)]
rd = random.randint(0, len(data) - 1)
k = data[rd]["keyboard_condition"]
if mouse:
m = data[rd]["mouse_condition"]
if current_frame == 0:
keyboard_condition[:1] = k[:1]
if mouse:
mouse_condition[:1] = m[:1]
current_frame = 1
else:
rd_frame = min(rd_frame, num_frames - current_frame)
repeat_time = rd_frame // 4
keyboard_condition[current_frame : current_frame + rd_frame] = k.repeat(repeat_time, 1)
if mouse:
mouse_condition[current_frame : current_frame + rd_frame] = m.repeat(repeat_time, 1)
current_frame += rd_frame
if mouse:
return {"keyboard_condition": keyboard_condition, "mouse_condition": mouse_condition}
return {"keyboard_condition": keyboard_condition}
def Bench_actions_universal(num_frames, num_samples_per_action=4):
actions_single_action = [
"forward",
# "back",
"left",
"right",
]
actions_double_action = [
"forward_left",
"forward_right",
# "back_left",
# "back_right",
]
actions_single_camera = [
"camera_l",
"camera_r",
# "camera_ur",
# "camera_ul",
# "camera_dl",
# "camera_dr"
# "camera_up",
# "camera_down",
]
actions_to_test = actions_double_action * 5 + actions_single_camera * 5 + actions_single_action * 5
for action in actions_single_action + actions_double_action:
for camera in actions_single_camera:
double_action = f"{action}_{camera}"
actions_to_test.append(double_action)
# print("length of actions: ", len(actions_to_test))
base_action = actions_single_action + actions_single_camera
KEYBOARD_IDX = {"forward": 0, "back": 1, "left": 2, "right": 3}
CAM_VALUE = 0.1
CAMERA_VALUE_MAP = {
"camera_up": [CAM_VALUE, 0],
"camera_down": [-CAM_VALUE, 0],
"camera_l": [0, -CAM_VALUE],
"camera_r": [0, CAM_VALUE],
"camera_ur": [CAM_VALUE, CAM_VALUE],
"camera_ul": [CAM_VALUE, -CAM_VALUE],
"camera_dr": [-CAM_VALUE, CAM_VALUE],
"camera_dl": [-CAM_VALUE, -CAM_VALUE],
}
data = []
for action_name in actions_to_test:
keyboard_condition = [[0, 0, 0, 0] for _ in range(num_samples_per_action)]
mouse_condition = [[0, 0] for _ in range(num_samples_per_action)]
for sub_act in base_action:
if sub_act not in action_name: # only handle sub-actions that appear in action_name
continue
# print(f"action name: {action_name} sub_act: {sub_act}")
if sub_act in CAMERA_VALUE_MAP:
mouse_condition = [CAMERA_VALUE_MAP[sub_act] for _ in range(num_samples_per_action)]
elif sub_act in KEYBOARD_IDX:
col = KEYBOARD_IDX[sub_act]
for row in keyboard_condition:
row[col] = 1
data.append({"keyboard_condition": torch.tensor(keyboard_condition), "mouse_condition": torch.tensor(mouse_condition)})
return combine_data(data, num_frames, keyboard_dim=4, mouse=True)
def Bench_actions_gta_drive(num_frames, num_samples_per_action=4):
actions_single_action = [
"forward",
"back",
]
actions_single_camera = [
"camera_l",
"camera_r",
]
actions_to_test = actions_single_camera * 2 + actions_single_action * 2
for action in actions_single_action:
for camera in actions_single_camera:
double_action = f"{action}_{camera}"
actions_to_test.append(double_action)
# print("length of actions: ", len(actions_to_test))
base_action = actions_single_action + actions_single_camera
KEYBOARD_IDX = {"forward": 0, "back": 1}
CAM_VALUE = 0.1
CAMERA_VALUE_MAP = {
"camera_l": [0, -CAM_VALUE],
"camera_r": [0, CAM_VALUE],
}
data = []
for action_name in actions_to_test:
keyboard_condition = [[0, 0] for _ in range(num_samples_per_action)]
mouse_condition = [[0, 0] for _ in range(num_samples_per_action)]
for sub_act in base_action:
if sub_act not in action_name: # only handle sub-actions that appear in action_name
continue
# print(f"action name: {action_name} sub_act: {sub_act}")
if sub_act in CAMERA_VALUE_MAP:
mouse_condition = [CAMERA_VALUE_MAP[sub_act] for _ in range(num_samples_per_action)]
elif sub_act in KEYBOARD_IDX:
col = KEYBOARD_IDX[sub_act]
for row in keyboard_condition:
row[col] = 1
data.append({"keyboard_condition": torch.tensor(keyboard_condition), "mouse_condition": torch.tensor(mouse_condition)})
return combine_data(data, num_frames, keyboard_dim=2, mouse=True)
def Bench_actions_templerun(num_frames, num_samples_per_action=4):
actions_single_action = ["jump", "slide", "leftside", "rightside", "turnleft", "turnright", "nomove"]
actions_to_test = actions_single_action
base_action = actions_single_action
KEYBOARD_IDX = {"nomove": 0, "jump": 1, "slide": 2, "turnleft": 3, "turnright": 4, "leftside": 5, "rightside": 6}
data = []
for action_name in actions_to_test:
keyboard_condition = [[0, 0, 0, 0, 0, 0, 0] for _ in range(num_samples_per_action)]
for sub_act in base_action:
if sub_act not in action_name: # only handle sub-actions that appear in action_name
continue
# print(f"action name: {action_name} sub_act: {sub_act}")
if sub_act in KEYBOARD_IDX:
col = KEYBOARD_IDX[sub_act]
for row in keyboard_condition:
row[col] = 1
data.append({"keyboard_condition": torch.tensor(keyboard_condition)})
return combine_data(data, num_frames, keyboard_dim=7, mouse=False)
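# Illustrative shape check for the benchmark helpers above (sketch, not part of the commit).
def _demo_bench_shapes(num_frames=57):
    # num_frames must satisfy num_frames % 4 == 1, which lines up with
    # 1 + 4 * (latent_frames - 1) pixel frames under the 4x temporal VAE compression.
    uni = Bench_actions_universal(num_frames)
    print(uni["keyboard_condition"].shape, uni["mouse_condition"].shape)  # [57, 4], [57, 2]
    gta = Bench_actions_gta_drive(num_frames)
    print(gta["keyboard_condition"].shape, gta["mouse_condition"].shape)  # [57, 2], [57, 2]
    tr = Bench_actions_templerun(num_frames)
    print(tr["keyboard_condition"].shape)  # [57, 7], keyboard only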
class MatrixGame2_Bench:
def __init__(self):
self.device = torch.device("cuda")
self.weight_dtype = torch.bfloat16
def get_conditions(self, mode, num_frames):
conditional_dict = {}
if mode == "universal":
cond_data = Bench_actions_universal(num_frames)
mouse_condition = cond_data["mouse_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
conditional_dict["mouse_cond"] = mouse_condition
elif mode == "gta_drive":
cond_data = Bench_actions_gta_drive(num_frames)
mouse_condition = cond_data["mouse_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
conditional_dict["mouse_cond"] = mouse_condition
else:
cond_data = Bench_actions_templerun(num_frames)
keyboard_condition = cond_data["keyboard_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
conditional_dict["keyboard_cond"] = keyboard_condition
return conditional_dict
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string
import ftfy
import regex as re
from transformers import AutoTokenizer
__all__ = ["HuggingfaceTokenizer"]
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def canonicalize(text, keep_punctuation_exact_string=None):
text = text.replace("_", " ")
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(part.translate(str.maketrans("", "", string.punctuation)) for part in text.split(keep_punctuation_exact_string))
else:
text = text.translate(str.maketrans("", "", string.punctuation))
text = text.lower()
text = re.sub(r"\s+", " ", text)
return text.strip()
class HuggingfaceTokenizer:
def __init__(self, name, seq_len=None, clean=None, **kwargs):
assert clean in (None, "whitespace", "lower", "canonicalize")
self.name = name
self.seq_len = seq_len
self.clean = clean
# init tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
self.vocab_size = self.tokenizer.vocab_size
def __call__(self, sequence, **kwargs):
return_mask = kwargs.pop("return_mask", False)
# arguments
_kwargs = {"return_tensors": "pt"}
if self.seq_len is not None:
_kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len})
_kwargs.update(**kwargs)
# tokenization
if isinstance(sequence, str):
sequence = [sequence]
if self.clean:
sequence = [self._clean(u) for u in sequence]
ids = self.tokenizer(sequence, **_kwargs)
# output
if return_mask:
return ids.input_ids, ids.attention_mask
else:
return ids.input_ids
def _clean(self, text):
if self.clean == "whitespace":
text = whitespace_clean(basic_clean(text))
elif self.clean == "lower":
text = whitespace_clean(basic_clean(text)).lower()
elif self.clean == "canonicalize":
text = canonicalize(basic_clean(text))
return text
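# Usage sketch for HuggingfaceTokenizer (illustrative; the tokenizer name is a placeholder for
# whichever xlm-roberta checkpoint ships with the model weights).
def _demo_tokenizer():
    tok = HuggingfaceTokenizer(name="xlm-roberta-large", seq_len=512, clean="whitespace")
    ids, mask = tok("a car driving down a rainy street", return_mask=True)
    print(ids.shape, mask.shape)  # both [1, 512] after padding/truncation to seq_len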
from typing import List, Tuple, Union
import torch
def _to_tuple(x, dim=2):
if isinstance(x, int):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start, *args, dim=2):
"""
Get n-D meshgrid with start, stop and num.
Args:
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
n-tuples.
*args: See above.
dim (int): Dimension of the meshgrid. Defaults to 2.
Returns:
grid (np.ndarray): [dim, ...]
"""
if len(args) == 0:
# start is grid_size
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = [stop[i] - start[i] for i in range(dim)]
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32, device=torch.cuda.current_device())[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
return grid
#################################################################################
# Rotary Positional Embedding Functions #
#################################################################################
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
def reshape_for_broadcast(
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
x: torch.Tensor,
head_first=False,
):
"""
Reshape frequency tensor for broadcasting it with another tensor.
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.
Notes:
When using FlashMHAModified, head_first should be False.
When using Attention, head_first should be True.
Args:
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
torch.Tensor: Reshaped frequency tensor.
Raises:
AssertionError: If the frequency tensor doesn't match the expected shape.
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
"""
ndim = x.ndim
assert 0 <= 1 < ndim
if isinstance(freqs_cis, tuple):
# freqs_cis: (cos, sin) in real space
if head_first:
assert freqs_cis[0].shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
# assert freqs_cis[0].shape == (
# x.shape[1],
# x.shape[-1],
# ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
# shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
shape = [1, freqs_cis[0].shape[0], 1, freqs_cis[0].shape[1]]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
else:
# freqs_cis: values in complex space
if head_first:
assert freqs_cis.shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis.shape == (
x.shape[1],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotate_half(x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
head_first: bool = False,
start_offset: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
returned as real tensors.
Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
# print(freqs_cis[0].shape, xq.shape, xk.shape)
xk_out = None
assert isinstance(freqs_cis, tuple)
if isinstance(freqs_cis, tuple):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
cos, sin = cos.to(xq.device), sin.to(xq.device)
# real * cos - imag * sin
# imag * cos + real * sin
xq_out = (xq.float() * cos[:, start_offset : start_offset + xq.shape[1], :, :] + rotate_half(xq.float()) * sin[:, start_offset : start_offset + xq.shape[1], :, :]).type_as(xq)
xk_out = (xk.float() * cos[:, start_offset : start_offset + xk.shape[1], :, :] + rotate_half(xk.float()) * sin[:, start_offset : start_offset + xk.shape[1], :, :]).type_as(xk)
else:
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
return xq_out, xk_out
def get_nd_rotary_pos_embed(
rope_dim_list,
start,
*args,
theta=10000.0,
use_real=False,
theta_rescale_factor: Union[float, List[float]] = 1.0,
interpolation_factor: Union[float, List[float]] = 1.0,
):
"""
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
Args:
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
sum(rope_dim_list) should equal to head_dim of attention layer.
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
*args: See above.
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
part and an imaginary part separately.
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
Returns:
pos_embed (torch.Tensor): [HW, D/2]
"""
grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)"
# use 1/ndim of dimensions to encode grid_axis
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i],
grid[i].reshape(-1),
theta,
use_real=use_real,
theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
) # 2 x [WHD, rope_dim_list[i]]
embs.append(emb)
if use_real:
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
return cos, sin
else:
emb = torch.cat(embs, dim=1) # (WHD, D/2)
return emb
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.FloatTensor, int],
theta: float = 10000.0,
use_real: bool = False,
theta_rescale_factor: float = 1.0,
interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
Returns:
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
"""
if isinstance(pos, int):
pos = torch.arange(pos, device=torch.cuda.current_device()).float()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=torch.cuda.current_device())[: (dim // 2)].float() / dim)) # [D/2]
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
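# Minimal sketch of the RoPE helpers above (illustrative; needs CUDA because the grids are
# allocated on the current CUDA device). rope_dim_list = [8, 28, 28] and theta = 256 mirror the
# action_config above; the 22 x 40 spatial grid presumably corresponds to
# frame_seq_length = 880 tokens per frame in the configs.
def _demo_rope():
    t, h, w = 3, 22, 40
    cos, sin = get_nd_rotary_pos_embed([8, 28, 28], (t, h, w), theta=256, use_real=True)
    print(cos.shape)  # [t * h * w, 64], where 8 + 28 + 28 = 64
    q = torch.randn(1, t * h * w, 16, 64, device="cuda")  # [B, S, H, D] with head_first=False
    k = torch.randn(1, t * h * w, 16, 64, device="cuda")
    q_rot, k_rot = apply_rotary_emb(q, k, (cos, sin), head_first=False)
    print(q_rot.shape, k_rot.shape)  # shapes unchanged, values rotated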
import torch
from lightx2v.models.networks.wan.infer.module_io import GridOutput, WanPreInferModuleOutput
from lightx2v.models.networks.wan.infer.self_forcing.pre_infer import WanSFPreInfer, sinusoidal_embedding_1d
from lightx2v.utils.envs import *
def cond_current(conditional_dict, current_start_frame, num_frame_per_block, replace=None, mode="universal"):
new_cond = {}
new_cond["cond_concat"] = conditional_dict["image_encoder_output"]["cond_concat"][:, :, current_start_frame : current_start_frame + num_frame_per_block]
new_cond["visual_context"] = conditional_dict["image_encoder_output"]["visual_context"]
if replace:
if current_start_frame == 0:
last_frame_num = 1 + 4 * (num_frame_per_block - 1)
else:
last_frame_num = 4 * num_frame_per_block
final_frame = 1 + 4 * (current_start_frame + num_frame_per_block - 1)
if mode != "templerun":
conditional_dict["text_encoder_output"]["mouse_cond"][:, -last_frame_num + final_frame : final_frame] = replace["mouse"][None, None, :].repeat(1, last_frame_num, 1)
conditional_dict["text_encoder_output"]["keyboard_cond"][:, -last_frame_num + final_frame : final_frame] = replace["keyboard"][None, None, :].repeat(1, last_frame_num, 1)
if mode != "templerun":
new_cond["mouse_cond"] = conditional_dict["text_encoder_output"]["mouse_cond"][:, : 1 + 4 * (current_start_frame + num_frame_per_block - 1)]
new_cond["keyboard_cond"] = conditional_dict["text_encoder_output"]["keyboard_cond"][:, : 1 + 4 * (current_start_frame + num_frame_per_block - 1)]
if replace:
return new_cond, conditional_dict
else:
return new_cond
# @amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
class WanMtxg2PreInfer(WanSFPreInfer):
def __init__(self, config):
super().__init__(config)
d = config["dim"] // config["num_heads"]
self.freqs = torch.cat([rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], dim=1).to(torch.device("cuda"))
self.dim = config["dim"]
def img_emb(self, weights, x):
x = weights.img_emb_0.apply(x)
x = weights.img_emb_1.apply(x.squeeze(0))
x = torch.nn.functional.gelu(x, approximate="none")
x = weights.img_emb_3.apply(x)
x = weights.img_emb_4.apply(x)
x = x.unsqueeze(0)
return x
@torch.no_grad()
def infer(self, weights, inputs, kv_start=0, kv_end=0):
x = self.scheduler.latents_input
t = self.scheduler.timestep_input
current_start_frame = self.scheduler.seg_index * self.scheduler.num_frame_per_block
if self.config["streaming"]:
current_actions = inputs["current_actions"]
current_conditional_dict, _ = cond_current(inputs, current_start_frame, self.scheduler.num_frame_per_block, replace=current_actions, mode=self.config["mode"])
else:
current_conditional_dict = cond_current(inputs, current_start_frame, self.scheduler.num_frame_per_block, mode=self.config["mode"])
cond_concat = current_conditional_dict["cond_concat"]
visual_context = current_conditional_dict["visual_context"]
x = torch.cat([x.unsqueeze(0), cond_concat], dim=1)
# embeddings
x = weights.patch_embedding.apply(x)
grid_sizes_t, grid_sizes_h, grid_sizes_w = torch.tensor(x.shape[2:], dtype=torch.long)
grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w))
x = x.flatten(2).transpose(1, 2) # B FHW C'
seq_lens = torch.tensor([u.size(0) for u in x], dtype=torch.long, device=torch.device("cuda"))
assert seq_lens[0] <= 15 * 1 * 880
embed_tmp = sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x) # torch.Size([3, 256])
embed = self.time_embedding(weights, embed_tmp) # torch.Size([3, 1536])
embed0 = self.time_projection(weights, embed).unflatten(dim=0, sizes=t.shape)
# context
context_lens = None
context = self.img_emb(weights, visual_context)
return WanPreInferModuleOutput(
embed=embed,
grid_sizes=grid_sizes,
x=x.squeeze(0),
embed0=embed0.squeeze(0),
seq_lens=seq_lens,
freqs=self.freqs,
context=context[0],
conditional_dict=current_conditional_dict,
)
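# Frame-index bookkeeping in cond_current(), spelled out for num_frame_per_block = 3 and the 4x
# temporal VAE compression used in the configs above (illustrative numbers):
#   block 0: current_start_frame = 0 -> keyboard/mouse sliced to [:, :1 + 4 * 2]  (9 frames)
#   block 1: current_start_frame = 3 -> sliced to [:, :1 + 4 * 5]  (21 frames)
#   block 2: current_start_frame = 6 -> sliced to [:, :1 + 4 * 8]  (33 frames)
# In streaming mode, `replace` overwrites the trailing action frames of the current block
# (9 for the first block, 4 * num_frame_per_block = 12 afterwards) with the live keyboard/mouse input.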
@@ -20,3 +20,4 @@ class WanPreInferModuleOutput:
     freqs: torch.Tensor
     context: torch.Tensor
     adapter_args: Dict[str, Any] = field(default_factory=dict)
+    conditional_dict: Dict[str, Any] = field(default_factory=dict)
import json
import os
import torch
from safetensors import safe_open
from lightx2v.models.networks.wan.infer.matrix_game2.pre_infer import WanMtxg2PreInfer
from lightx2v.models.networks.wan.infer.matrix_game2.transformer_infer import WanMtxg2TransformerInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.sf_model import WanSFModel
from lightx2v.models.networks.wan.weights.matrix_game2.pre_weights import WanMtxg2PreWeights
from lightx2v.models.networks.wan.weights.matrix_game2.transformer_weights import WanActionTransformerWeights
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *
class WanSFMtxg2Model(WanSFModel):
pre_weight_class = WanMtxg2PreWeights
transformer_weight_class = WanActionTransformerWeights
def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
with safe_open(file_path, framework="pt", device=str(self.device)) as f:
return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())) for key in f.keys()}
def _load_ckpt(self, unified_dtype, sensitive_layer):
file_path = os.path.join(self.config["model_path"], f"{self.config['sub_model_folder']}/{self.config['sub_model_name']}")
_weight_dict = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
weight_dict = {}
for k, v in _weight_dict.items():
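# strip the leading 6-character prefix from each checkpoint key (presumably "model.")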
name = k[6:]
weight = v.to(torch.bfloat16).to(self.device)
weight_dict.update({name: weight})
del _weight_dict
return weight_dict
def _init_infer_class(self):
# update config by real model config
with open(os.path.join(self.config["model_path"], self.config["sub_model_folder"], "config.json")) as f:
model_config = json.load(f)
for k in model_config.keys():
self.config[k] = model_config[k]
self.pre_infer_class = WanMtxg2PreInfer
self.post_infer_class = WanPostInfer
self.transformer_infer_class = WanMtxg2TransformerInfer
@@ -11,7 +11,8 @@ from lightx2v.models.networks.wan.model import WanModel
 class WanSFModel(WanModel):
     def __init__(self, model_path, config, device):
         super().__init__(model_path, config, device)
-        self.to_cuda()
+        if config["model_cls"] not in ["wan2.1_sf_mtxg2"]:
+            self.to_cuda()
     def _load_ckpt(self, unified_dtype, sensitive_layer):
         sf_confg = self.config["sf_config"]
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import (
CONV3D_WEIGHT_REGISTER,
LN_WEIGHT_REGISTER,
MM_WEIGHT_REGISTER,
)
class WanMtxg2PreWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.in_dim = config["in_dim"]
self.dim = config["dim"]
self.patch_size = (1, 2, 2)
self.config = config
# patch
self.add_module(
"patch_embedding",
CONV3D_WEIGHT_REGISTER["Default"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size),
)
# time
self.add_module(
"time_embedding_0",
MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias"),
)
self.add_module(
"time_embedding_2",
MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias"),
)
self.add_module(
"time_projection_1",
MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"),
)
# img_emb
self.add_module(
"img_emb_0",
LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias", eps=1e-5),
)
self.add_module(
"img_emb_1",
MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias"),
)
self.add_module(
"img_emb_3",
MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias"),
)
self.add_module(
"img_emb_4",
LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias", eps=1e-5),
)
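# Note on the img_emb_* modules above (descriptive, inferred from WanMtxg2PreInfer.img_emb):
# they map to img_emb.proj.0/1/3/4 in the checkpoint (LayerNorm, Linear, Linear, LayerNorm);
# the GELU that would sit at index 2 carries no weights and is applied functionally between
# img_emb_1 and img_emb_3 during inference.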