Unverified Commit 9826b8ca authored by Watebear, committed by GitHub

[feat]: support matrix game2 universal, gta_drive, templerun & streaming mode

parent 44e215f3
{
"infer_steps": 50,
"target_video_length": 150,
"num_output_frames": 150,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
"seed": 0,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"cpu_offload": false,
"sf_config": {
"local_attn_size": 6,
"shift": 5.0,
"num_frame_per_block": 3,
"num_transformer_blocks": 30,
"frame_seq_length": 880,
"num_output_frames": 150,
"num_inference_steps": 1000,
"denoising_step_list": [1000.0000, 908.8427, 713.9794]
},
"sub_model_folder": "gta_distilled_model",
"sub_model_name": "gta_keyboard2dim.safetensors",
"mode": "gta_drive",
"streaming": false,
"action_config": {
"blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"enable_keyboard": true,
"enable_mouse": true,
"heads_num": 16,
"hidden_size": 128,
"img_hidden_size": 1536,
"keyboard_dim_in": 4,
"keyboard_hidden_dim": 1024,
"mouse_dim_in": 2,
"mouse_hidden_dim": 1024,
"mouse_qk_dim_list": [
8,
28,
28
],
"patch_size": [
1,
2,
2
],
"qk_norm": true,
"qkv_bias": false,
"rope_dim_list": [
8,
28,
28
],
"rope_theta": 256,
"vae_time_compression_ratio": 4,
"windows_size": 3
},
"dim": 1536,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_dim": 36,
"inject_sample_info": false,
"model_type": "i2v",
"num_heads": 12,
"num_layers": 30,
"out_dim": 16
}
{
"infer_steps": 50,
"target_video_length": 360,
"num_output_frames": 360,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
"seed": 0,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"cpu_offload": false,
"sf_config": {
"local_attn_size": 6,
"shift": 5.0,
"num_frame_per_block": 3,
"num_transformer_blocks": 30,
"frame_seq_length": 880,
"num_output_frames": 360,
"num_inference_steps": 1000,
"denoising_step_list": [1000.0000, 908.8427, 713.9794]
},
"sub_model_folder": "gta_distilled_model",
"sub_model_name": "gta_keyboard2dim.safetensors",
"mode": "gta_drive",
"streaming": true,
"action_config": {
"blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"enable_keyboard": true,
"enable_mouse": true,
"heads_num": 16,
"hidden_size": 128,
"img_hidden_size": 1536,
"keyboard_dim_in": 4,
"keyboard_hidden_dim": 1024,
"mouse_dim_in": 2,
"mouse_hidden_dim": 1024,
"mouse_qk_dim_list": [
8,
28,
28
],
"patch_size": [
1,
2,
2
],
"qk_norm": true,
"qkv_bias": false,
"rope_dim_list": [
8,
28,
28
],
"rope_theta": 256,
"vae_time_compression_ratio": 4,
"windows_size": 3
},
"dim": 1536,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_dim": 36,
"inject_sample_info": false,
"model_type": "i2v",
"num_heads": 12,
"num_layers": 30,
"out_dim": 16
}
{
"infer_steps": 50,
"target_video_length": 150,
"num_output_frames": 150,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
"seed": 0,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"cpu_offload": false,
"sf_config": {
"local_attn_size": 6,
"shift": 5.0,
"num_frame_per_block": 3,
"num_transformer_blocks": 30,
"frame_seq_length": 880,
"num_output_frames": 150,
"num_inference_steps": 1000,
"denoising_step_list": [1000.0000, 908.8427, 713.9794]
},
"sub_model_folder": "templerun_distilled_model",
"sub_model_name": "templerun_7dim_onlykey.safetensors",
"mode": "templerun",
"streaming": false,
"action_config": {
"blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"enable_keyboard": true,
"enable_mouse": false,
"heads_num": 16,
"hidden_size": 128,
"img_hidden_size": 1536,
"keyboard_dim_in": 7,
"keyboard_hidden_dim": 1024,
"patch_size": [
1,
2,
2
],
"qk_norm": true,
"qkv_bias": false,
"rope_dim_list": [
8,
28,
28
],
"rope_theta": 256,
"vae_time_compression_ratio": 4,
"windows_size": 3
},
"dim": 1536,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_dim": 36,
"inject_sample_info": false,
"model_type": "i2v",
"num_heads": 12,
"num_layers": 30,
"out_dim": 16
}
{
"infer_steps": 50,
"target_video_length": 360,
"num_output_frames": 360,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
"seed": 0,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"cpu_offload": false,
"sf_config": {
"local_attn_size": 6,
"shift": 5.0,
"num_frame_per_block": 3,
"num_transformer_blocks": 30,
"frame_seq_length": 880,
"num_output_frames": 360,
"num_inference_steps": 1000,
"denoising_step_list": [1000.0000, 908.8427, 713.9794]
},
"sub_model_folder": "templerun_distilled_model",
"sub_model_name": "templerun_7dim_onlykey.safetensors",
"mode": "templerun",
"streaming": true,
"action_config": {
"blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"enable_keyboard": true,
"enable_mouse": true,
"heads_num": 16,
"hidden_size": 128,
"img_hidden_size": 1536,
"keyboard_dim_in": 4,
"keyboard_hidden_dim": 1024,
"mouse_dim_in": 2,
"mouse_hidden_dim": 1024,
"mouse_qk_dim_list": [
8,
28,
28
],
"patch_size": [
1,
2,
2
],
"qk_norm": true,
"qkv_bias": false,
"rope_dim_list": [
8,
28,
28
],
"rope_theta": 256,
"vae_time_compression_ratio": 4,
"windows_size": 3
},
"dim": 1536,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_dim": 36,
"inject_sample_info": false,
"model_type": "i2v",
"num_heads": 12,
"num_layers": 30,
"out_dim": 16
}
{
"infer_steps": 50,
"target_video_length": 150,
"num_output_frames": 150,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
"seed": 0,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"cpu_offload": false,
"sf_config": {
"local_attn_size": 6,
"shift": 5.0,
"num_frame_per_block": 3,
"num_transformer_blocks": 30,
"frame_seq_length": 880,
"num_output_frames": 150,
"num_inference_steps": 1000,
"denoising_step_list": [1000.0000, 908.8427, 713.9794]
},
"sub_model_folder": "base_distilled_model",
"sub_model_name": "base_distill.safetensors",
"mode": "universal",
"streaming": false,
"action_config": {
"blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"enable_keyboard": true,
"enable_mouse": true,
"heads_num": 16,
"hidden_size": 128,
"img_hidden_size": 1536,
"keyboard_dim_in": 4,
"keyboard_hidden_dim": 1024,
"mouse_dim_in": 2,
"mouse_hidden_dim": 1024,
"mouse_qk_dim_list": [
8,
28,
28
],
"patch_size": [
1,
2,
2
],
"qk_norm": true,
"qkv_bias": false,
"rope_dim_list": [
8,
28,
28
],
"rope_theta": 256,
"vae_time_compression_ratio": 4,
"windows_size": 3
},
"dim": 1536,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_dim": 36,
"inject_sample_info": false,
"model_type": "i2v",
"num_heads": 12,
"num_layers": 30,
"out_dim": 16
}
{
"infer_steps": 50,
"target_video_length": 360,
"num_output_frames": 360,
"text_len": 512,
"target_height": 480,
"target_width": 832,
"self_attn_1_type": "flash_attn2",
"cross_attn_1_type": "flash_attn2",
"cross_attn_2_type": "flash_attn2",
"seed": 0,
"sample_guide_scale": 6,
"sample_shift": 8,
"enable_cfg": false,
"cpu_offload": false,
"sf_config": {
"local_attn_size": 6,
"shift": 5.0,
"num_frame_per_block": 3,
"num_transformer_blocks": 30,
"frame_seq_length": 880,
"num_output_frames": 360,
"num_inference_steps": 1000,
"denoising_step_list": [1000.0000, 908.8427, 713.9794]
},
"sub_model_folder": "base_distilled_model",
"sub_model_name": "base_distill.safetensors",
"mode": "universal",
"streaming": true,
"action_config": {
"blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
"enable_keyboard": true,
"enable_mouse": true,
"heads_num": 16,
"hidden_size": 128,
"img_hidden_size": 1536,
"keyboard_dim_in": 4,
"keyboard_hidden_dim": 1024,
"mouse_dim_in": 2,
"mouse_hidden_dim": 1024,
"mouse_qk_dim_list": [
8,
28,
28
],
"patch_size": [
1,
2,
2
],
"qk_norm": true,
"qkv_bias": false,
"rope_dim_list": [
8,
28,
28
],
"rope_theta": 256,
"vae_time_compression_ratio": 4,
"windows_size": 3
},
"dim": 1536,
"eps": 1e-06,
"ffn_dim": 8960,
"freq_dim": 256,
"in_dim": 36,
"inject_sample_info": false,
"model_type": "i2v",
"num_heads": 12,
"num_layers": 30,
"out_dim": 16
}
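The six JSON files above are the inference configs added for the three Matrix-Game2 modes (gta_drive, templerun, universal), each in a non-streaming and a streaming variant; within a mode the streaming config differs mainly in "streaming": true and the longer target_video_length / num_output_frames (360 vs 150). A minimal sketch of inspecting one of them (the path below is a placeholder, not necessarily the repo's layout):

import json

# Placeholder path; the actual config directory in the repo may differ.
with open("configs/matrix_game2/gta_drive_streaming.json") as f:
    cfg = json.load(f)

# Fields the Matrix-Game2 runner branches on.
print(cfg["mode"], cfg["streaming"])             # e.g. "gta_drive", True
print(cfg["sf_config"]["num_frame_per_block"])   # latent frames denoised per block (3)
print(cfg["action_config"]["enable_mouse"],      # mouse/camera conditioning on or off
      cfg["action_config"]["keyboard_dim_in"])   # width of the per-frame keyboard vector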
@@ -51,7 +51,7 @@ except ImportError:
 try:
     from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
-except ModuleNotFoundError:
+except ImportError:
     quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
 try:
@@ -61,7 +61,7 @@ except ImportError:
 try:
     import marlin_cuda_quant
-except ModuleNotFoundError:
+except ImportError:
     marlin_cuda_quant = None
@@ -9,6 +9,7 @@ from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner
 from lightx2v.models.runners.wan.wan_animate_runner import WanAnimateRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
+from lightx2v.models.runners.wan.wan_matrix_game2_runner import WanSFMtxg2Runner  # noqa: F401
 from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner  # noqa: F401
@@ -39,6 +40,7 @@ def main():
         "wan2.1_distill",
         "wan2.1_vace",
         "wan2.1_sf",
+        "wan2.1_sf_mtxg2",
         "seko_talk",
         "wan2.2_moe",
         "wan2.2",
@@ -3,7 +3,7 @@ import torch.nn as nn
 try:
     from vllm import _custom_ops as ops
-except ModuleNotFoundError:
+except ImportError:
     ops = None
 try:
@@ -13,7 +13,7 @@ except ImportError:
 try:
     from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
-except ModuleNotFoundError:
+except ImportError:
     quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
 try:
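One note on the except-clause changes above: ModuleNotFoundError is a subclass of ImportError (Python 3.6+), so catching ImportError still covers a missing optional dependency and additionally handles imports that fail for other reasons (for example a broken native extension). A tiny illustration:

# ModuleNotFoundError is a subclass of ImportError, so the broader clause
# still catches a completely absent optional dependency.
assert issubclass(ModuleNotFoundError, ImportError)

try:
    import some_missing_optional_module  # hypothetical module name
except ImportError:
    some_missing_optional_module = None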
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from diffusers.models import ModelMixin
from lightx2v.models.input_encoders.hf.wan.matrix_game2.tokenizers import HuggingfaceTokenizer
from lightx2v.models.input_encoders.hf.wan.xlm_roberta.model import VisionTransformer
class SelfAttention(nn.Module):
def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
assert dim % num_heads == 0
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.eps = eps
# layers
self.q = nn.Linear(dim, dim)
self.k = nn.Linear(dim, dim)
self.v = nn.Linear(dim, dim)
self.o = nn.Linear(dim, dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x, mask):
"""
x: [B, L, C].
"""
b, s, c, n, d = *x.size(), self.num_heads, self.head_dim
# compute query, key, value
q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
# compute attention
p = self.dropout.p if self.training else 0.0
x = F.scaled_dot_product_attention(q, k, v, mask, p)
x = x.permute(0, 2, 1, 3).reshape(b, s, c)
# output
x = self.o(x)
x = self.dropout(x)
return x
class AttentionBlock(nn.Module):
def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.post_norm = post_norm
self.eps = eps
# layers
self.attn = SelfAttention(dim, num_heads, dropout, eps)
self.norm1 = nn.LayerNorm(dim, eps=eps)
self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout))
self.norm2 = nn.LayerNorm(dim, eps=eps)
def forward(self, x, mask):
if self.post_norm:
x = self.norm1(x + self.attn(x, mask))
x = self.norm2(x + self.ffn(x))
else:
x = x + self.attn(self.norm1(x), mask)
x = x + self.ffn(self.norm2(x))
return x
class XLMRoberta(nn.Module):
"""
XLMRobertaModel with no pooler and no LM head.
"""
def __init__(self, vocab_size=250002, max_seq_len=514, type_size=1, pad_id=1, dim=1024, num_heads=16, num_layers=24, post_norm=True, dropout=0.1, eps=1e-5):
super().__init__()
self.vocab_size = vocab_size
self.max_seq_len = max_seq_len
self.type_size = type_size
self.pad_id = pad_id
self.dim = dim
self.num_heads = num_heads
self.num_layers = num_layers
self.post_norm = post_norm
self.eps = eps
# embeddings
self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
self.type_embedding = nn.Embedding(type_size, dim)
self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
self.dropout = nn.Dropout(dropout)
# blocks
self.blocks = nn.ModuleList([AttentionBlock(dim, num_heads, post_norm, dropout, eps) for _ in range(num_layers)])
# norm layer
self.norm = nn.LayerNorm(dim, eps=eps)
def forward(self, ids):
"""
ids: [B, L] of torch.LongTensor.
"""
b, s = ids.shape
mask = ids.ne(self.pad_id).long()
# embeddings
x = self.token_embedding(ids) + self.type_embedding(torch.zeros_like(ids)) + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
if self.post_norm:
x = self.norm(x)
x = self.dropout(x)
# blocks
mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
for block in self.blocks:
x = block(x, mask)
# output
if not self.post_norm:
x = self.norm(x)
return x
class XLMRobertaWithHead(XLMRoberta):
def __init__(self, **kwargs):
self.out_dim = kwargs.pop("out_dim")
super().__init__(**kwargs)
# head
mid_dim = (self.dim + self.out_dim) // 2
self.head = nn.Sequential(nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False))
def forward(self, ids):
# xlm-roberta
x = super().forward(ids)
# average pooling
mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
x = (x * mask).sum(dim=1) / mask.sum(dim=1)
# head
x = self.head(x)
return x
class XLMRobertaCLIP(nn.Module):
def __init__(
self,
dtype=torch.float16,
embed_dim=1024,
image_size=224,
patch_size=14,
vision_dim=1280,
vision_mlp_ratio=4,
vision_heads=16,
vision_layers=32,
vision_pool="token",
vision_pre_norm=True,
vision_post_norm=False,
activation="gelu",
vocab_size=250002,
max_text_len=514,
type_size=1,
pad_id=1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
norm_eps=1e-5,
quantized=False,
quant_scheme=None,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
):
super().__init__()
self.embed_dim = embed_dim
self.image_size = image_size
self.patch_size = patch_size
self.vision_dim = vision_dim
self.vision_mlp_ratio = vision_mlp_ratio
self.vision_heads = vision_heads
self.vision_layers = vision_layers
self.vision_pre_norm = vision_pre_norm
self.vision_post_norm = vision_post_norm
self.activation = activation
self.vocab_size = vocab_size
self.max_text_len = max_text_len
self.type_size = type_size
self.pad_id = pad_id
self.norm_eps = norm_eps
# models
self.visual = VisionTransformer(
dtype=dtype,
image_size=image_size,
patch_size=patch_size,
dim=vision_dim,
mlp_ratio=vision_mlp_ratio,
out_dim=embed_dim,
num_heads=vision_heads,
num_layers=vision_layers,
pool_type=vision_pool,
pre_norm=vision_pre_norm,
post_norm=vision_post_norm,
activation=activation,
attn_dropout=attn_dropout,
proj_dropout=proj_dropout,
embedding_dropout=embedding_dropout,
norm_eps=norm_eps,
quantized=quantized,
quant_scheme=quant_scheme,
)
self.textual = XLMRobertaWithHead(
vocab_size=vocab_size,
max_seq_len=max_text_len,
type_size=type_size,
pad_id=pad_id,
dim=text_dim,
out_dim=embed_dim,
num_heads=text_heads,
num_layers=text_layers,
post_norm=text_post_norm,
dropout=text_dropout,
)
self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))
def _clip(pretrained=False, pretrained_name=None, model_cls=XLMRobertaCLIP, return_transforms=False, return_tokenizer=False, tokenizer_padding="eos", dtype=torch.float32, device="cpu", **kwargs):
# init a model on device
with torch.device(device):
model = model_cls(**kwargs)
# set device
model = model.to(dtype=dtype, device=device)
output = (model,)
# init transforms
if return_transforms:
# mean and std
if "siglip" in pretrained_name.lower():
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
else:
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
# transforms
transforms = T.Compose([T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=mean, std=std)])
output += (transforms,)
return output[0] if len(output) == 1 else output
def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs):
cfg = dict(
embed_dim=1024,
image_size=224,
patch_size=14,
vision_dim=1280,
vision_mlp_ratio=4,
vision_heads=16,
vision_layers=32,
vision_pool="token",
activation="gelu",
vocab_size=250002,
max_text_len=514,
type_size=1,
pad_id=1,
text_dim=1024,
text_heads=16,
text_layers=24,
text_post_norm=True,
text_dropout=0.1,
attn_dropout=0.0,
proj_dropout=0.0,
embedding_dropout=0.0,
)
cfg.update(**kwargs)
return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)
class CLIPModel(ModelMixin):
def __init__(self, checkpoint_path, tokenizer_path):
super().__init__()
self.checkpoint_path = checkpoint_path
self.tokenizer_path = tokenizer_path
# init model
self.model, self.transforms = clip_xlm_roberta_vit_h_14(
pretrained=False,
return_transforms=True,
return_tokenizer=False,
)
self.model = self.model.eval().requires_grad_(False)
logging.info(f"loading {checkpoint_path}")
self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
# init tokenizer
self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace")
def encode_video(self, video):
# preprocess
b, c, t, h, w = video.shape
video = video.transpose(1, 2)
video = video.reshape(b * t, c, h, w)
size = (self.model.image_size,) * 2
video = F.interpolate(video, size=size, mode="bicubic", align_corners=False)
video = self.transforms.transforms[-1](video.mul_(0.5).add_(0.5))
# forward
with torch.amp.autocast(dtype=self.dtype, device_type=self.device.type):
out = self.model.visual(video, use_31_block=True)
return out
def forward(self, videos):
# preprocess
size = (self.model.image_size,) * 2
videos = torch.cat([F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos])
videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))
# forward
with torch.amp.autocast("cuda", dtype=self.dtype):
out = self.model.visual(videos, use_31_block=True)
return out
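A minimal usage sketch for the CLIP wrapper defined above, assuming a CUDA device and locally downloaded weights (both paths below are placeholders, not the repo's actual layout):

import torch

# Placeholder paths for the open-clip-xlm-roberta-vit-huge-14 weights and tokenizer.
clip = CLIPModel(checkpoint_path="open-clip-xlm-roberta-large-vit-huge-14.pth",
                 tokenizer_path="xlm-roberta-large").to("cuda", dtype=torch.float16)

# encode_video expects [B, C, T, H, W] in [-1, 1]; frames are resized to 224x224 internally.
video = torch.rand(1, 3, 4, 480, 832, device="cuda", dtype=torch.float16) * 2 - 1
with torch.no_grad():
    feats = clip.encode_video(video)  # intermediate visual-transformer features per frame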
import random
import torch
def combine_data(data, num_frames=57, keyboard_dim=6, mouse=True):
assert num_frames % 4 == 1
keyboard_condition = torch.zeros((num_frames, keyboard_dim))
if mouse:
mouse_condition = torch.zeros((num_frames, 2))
current_frame = 0
selections = [12]
while current_frame < num_frames:
rd_frame = selections[random.randint(0, len(selections) - 1)]
rd = random.randint(0, len(data) - 1)
k = data[rd]["keyboard_condition"]
if mouse:
m = data[rd]["mouse_condition"]
if current_frame == 0:
keyboard_condition[:1] = k[:1]
if mouse:
mouse_condition[:1] = m[:1]
current_frame = 1
else:
rd_frame = min(rd_frame, num_frames - current_frame)
repeat_time = rd_frame // 4
keyboard_condition[current_frame : current_frame + rd_frame] = k.repeat(repeat_time, 1)
if mouse:
mouse_condition[current_frame : current_frame + rd_frame] = m.repeat(repeat_time, 1)
current_frame += rd_frame
if mouse:
return {"keyboard_condition": keyboard_condition, "mouse_condition": mouse_condition}
return {"keyboard_condition": keyboard_condition}
def Bench_actions_universal(num_frames, num_samples_per_action=4):
actions_single_action = [
"forward",
# "back",
"left",
"right",
]
actions_double_action = [
"forward_left",
"forward_right",
# "back_left",
# "back_right",
]
actions_single_camera = [
"camera_l",
"camera_r",
# "camera_ur",
# "camera_ul",
# "camera_dl",
# "camera_dr"
# "camera_up",
# "camera_down",
]
actions_to_test = actions_double_action * 5 + actions_single_camera * 5 + actions_single_action * 5
for action in actions_single_action + actions_double_action:
for camera in actions_single_camera:
double_action = f"{action}_{camera}"
actions_to_test.append(double_action)
# print("length of actions: ", len(actions_to_test))
base_action = actions_single_action + actions_single_camera
KEYBOARD_IDX = {"forward": 0, "back": 1, "left": 2, "right": 3}
CAM_VALUE = 0.1
CAMERA_VALUE_MAP = {
"camera_up": [CAM_VALUE, 0],
"camera_down": [-CAM_VALUE, 0],
"camera_l": [0, -CAM_VALUE],
"camera_r": [0, CAM_VALUE],
"camera_ur": [CAM_VALUE, CAM_VALUE],
"camera_ul": [CAM_VALUE, -CAM_VALUE],
"camera_dr": [-CAM_VALUE, CAM_VALUE],
"camera_dl": [-CAM_VALUE, -CAM_VALUE],
}
data = []
for action_name in actions_to_test:
keyboard_condition = [[0, 0, 0, 0] for _ in range(num_samples_per_action)]
mouse_condition = [[0, 0] for _ in range(num_samples_per_action)]
for sub_act in base_action:
if sub_act not in action_name:  # only handle sub-actions contained in action_name
continue
# print(f"action name: {action_name} sub_act: {sub_act}")
if sub_act in CAMERA_VALUE_MAP:
mouse_condition = [CAMERA_VALUE_MAP[sub_act] for _ in range(num_samples_per_action)]
elif sub_act in KEYBOARD_IDX:
col = KEYBOARD_IDX[sub_act]
for row in keyboard_condition:
row[col] = 1
data.append({"keyboard_condition": torch.tensor(keyboard_condition), "mouse_condition": torch.tensor(mouse_condition)})
return combine_data(data, num_frames, keyboard_dim=4, mouse=True)
def Bench_actions_gta_drive(num_frames, num_samples_per_action=4):
actions_single_action = [
"forward",
"back",
]
actions_single_camera = [
"camera_l",
"camera_r",
]
actions_to_test = actions_single_camera * 2 + actions_single_action * 2
for action in actions_single_action:
for camera in actions_single_camera:
double_action = f"{action}_{camera}"
actions_to_test.append(double_action)
# print("length of actions: ", len(actions_to_test))
base_action = actions_single_action + actions_single_camera
KEYBOARD_IDX = {"forward": 0, "back": 1}
CAM_VALUE = 0.1
CAMERA_VALUE_MAP = {
"camera_l": [0, -CAM_VALUE],
"camera_r": [0, CAM_VALUE],
}
data = []
for action_name in actions_to_test:
keyboard_condition = [[0, 0] for _ in range(num_samples_per_action)]
mouse_condition = [[0, 0] for _ in range(num_samples_per_action)]
for sub_act in base_action:
if sub_act not in action_name:  # only handle sub-actions contained in action_name
continue
# print(f"action name: {action_name} sub_act: {sub_act}")
if sub_act in CAMERA_VALUE_MAP:
mouse_condition = [CAMERA_VALUE_MAP[sub_act] for _ in range(num_samples_per_action)]
elif sub_act in KEYBOARD_IDX:
col = KEYBOARD_IDX[sub_act]
for row in keyboard_condition:
row[col] = 1
data.append({"keyboard_condition": torch.tensor(keyboard_condition), "mouse_condition": torch.tensor(mouse_condition)})
return combine_data(data, num_frames, keyboard_dim=2, mouse=True)
def Bench_actions_templerun(num_frames, num_samples_per_action=4):
actions_single_action = ["jump", "slide", "leftside", "rightside", "turnleft", "turnright", "nomove"]
actions_to_test = actions_single_action
base_action = actions_single_action
KEYBOARD_IDX = {"nomove": 0, "jump": 1, "slide": 2, "turnleft": 3, "turnright": 4, "leftside": 5, "rightside": 6}
data = []
for action_name in actions_to_test:
keyboard_condition = [[0, 0, 0, 0, 0, 0, 0] for _ in range(num_samples_per_action)]
for sub_act in base_action:
if sub_act not in action_name:  # only handle sub-actions contained in action_name
continue
# print(f"action name: {action_name} sub_act: {sub_act}")
elif sub_act in KEYBOARD_IDX:
col = KEYBOARD_IDX[sub_act]
for row in keyboard_condition:
row[col] = 1
data.append({"keyboard_condition": torch.tensor(keyboard_condition)})
return combine_data(data, num_frames, keyboard_dim=7, mouse=False)
class MatrixGame2_Bench:
def __init__(self):
self.device = torch.device("cuda")
self.weight_dtype = torch.bfloat16
def get_conditions(self, mode, num_frames):
conditional_dict = {}
if mode == "universal":
cond_data = Bench_actions_universal(num_frames)
mouse_condition = cond_data["mouse_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
conditional_dict["mouse_cond"] = mouse_condition
elif mode == "gta_drive":
cond_data = Bench_actions_gta_drive(num_frames)
mouse_condition = cond_data["mouse_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
conditional_dict["mouse_cond"] = mouse_condition
else:
cond_data = Bench_actions_templerun(num_frames)
keyboard_condition = cond_data["keyboard_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
conditional_dict["keyboard_cond"] = keyboard_condition
return conditional_dict
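A small sketch of what the benchmark helpers above produce; combine_data requires num_frames % 4 == 1 (one latent frame covers the first raw frame, every later latent frame covers four), so 57 is a valid choice:

cond = Bench_actions_gta_drive(num_frames=57)
print(cond["keyboard_condition"].shape)  # torch.Size([57, 2]) - forward/back flags per raw frame
print(cond["mouse_condition"].shape)     # torch.Size([57, 2]) - camera deltas per raw frame

cond = Bench_actions_templerun(num_frames=57)
print(cond["keyboard_condition"].shape)  # torch.Size([57, 7]) - one-hot temple-run action per raw frame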
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string
import ftfy
import regex as re
from transformers import AutoTokenizer
__all__ = ["HuggingfaceTokenizer"]
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r"\s+", " ", text)
text = text.strip()
return text
def canonicalize(text, keep_punctuation_exact_string=None):
text = text.replace("_", " ")
if keep_punctuation_exact_string:
text = keep_punctuation_exact_string.join(part.translate(str.maketrans("", "", string.punctuation)) for part in text.split(keep_punctuation_exact_string))
else:
text = text.translate(str.maketrans("", "", string.punctuation))
text = text.lower()
text = re.sub(r"\s+", " ", text)
return text.strip()
class HuggingfaceTokenizer:
def __init__(self, name, seq_len=None, clean=None, **kwargs):
assert clean in (None, "whitespace", "lower", "canonicalize")
self.name = name
self.seq_len = seq_len
self.clean = clean
# init tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
self.vocab_size = self.tokenizer.vocab_size
def __call__(self, sequence, **kwargs):
return_mask = kwargs.pop("return_mask", False)
# arguments
_kwargs = {"return_tensors": "pt"}
if self.seq_len is not None:
_kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len})
_kwargs.update(**kwargs)
# tokenization
if isinstance(sequence, str):
sequence = [sequence]
if self.clean:
sequence = [self._clean(u) for u in sequence]
ids = self.tokenizer(sequence, **_kwargs)
# output
if return_mask:
return ids.input_ids, ids.attention_mask
else:
return ids.input_ids
def _clean(self, text):
if self.clean == "whitespace":
text = whitespace_clean(basic_clean(text))
elif self.clean == "lower":
text = whitespace_clean(basic_clean(text)).lower()
elif self.clean == "canonicalize":
text = canonicalize(basic_clean(text))
return text
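A quick sketch of the tokenizer wrapper above ("xlm-roberta-large" is only an illustrative HuggingFace model name; CLIPModel passes its own tokenizer_path and seq_len = max_text_len - 2):

tok = HuggingfaceTokenizer(name="xlm-roberta-large", seq_len=512, clean="whitespace")
ids, mask = tok("Hello,   world!", return_mask=True)
print(ids.shape, mask.shape)  # torch.Size([1, 512]) torch.Size([1, 512]) - padded to max_length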
from typing import List, Tuple, Union
import torch
def _to_tuple(x, dim=2):
if isinstance(x, int):
return (x,) * dim
elif len(x) == dim:
return x
else:
raise ValueError(f"Expected length {dim} or int, but got {x}")
def get_meshgrid_nd(start, *args, dim=2):
"""
Get n-D meshgrid with start, stop and num.
Args:
start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
n-tuples.
*args: See above.
dim (int): Dimension of the meshgrid. Defaults to 2.
Returns:
grid (np.ndarray): [dim, ...]
"""
if len(args) == 0:
# start is grid_size
num = _to_tuple(start, dim=dim)
start = (0,) * dim
stop = num
elif len(args) == 1:
# start is start, args[0] is stop, step is 1
start = _to_tuple(start, dim=dim)
stop = _to_tuple(args[0], dim=dim)
num = [stop[i] - start[i] for i in range(dim)]
elif len(args) == 2:
# start is start, args[0] is stop, args[1] is num
start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
else:
raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
# PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
axis_grid = []
for i in range(dim):
a, b, n = start[i], stop[i], num[i]
g = torch.linspace(a, b, n + 1, dtype=torch.float32, device=torch.cuda.current_device())[:n]
axis_grid.append(g)
grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
grid = torch.stack(grid, dim=0) # [dim, W, H, D]
return grid
#################################################################################
# Rotary Positional Embedding Functions #
#################################################################################
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
def reshape_for_broadcast(
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
x: torch.Tensor,
head_first=False,
):
"""
Reshape frequency tensor for broadcasting it with another tensor.
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.
Notes:
When using FlashMHAModified, head_first should be False.
When using Attention, head_first should be True.
Args:
freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
x (torch.Tensor): Target tensor for broadcasting compatibility.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
torch.Tensor: Reshaped frequency tensor.
Raises:
AssertionError: If the frequency tensor doesn't match the expected shape.
AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
"""
ndim = x.ndim
assert 0 <= 1 < ndim
if isinstance(freqs_cis, tuple):
# freqs_cis: (cos, sin) in real space
if head_first:
assert freqs_cis[0].shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
# assert freqs_cis[0].shape == (
# x.shape[1],
# x.shape[-1],
# ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
# shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
shape = [1, freqs_cis[0].shape[0], 1, freqs_cis[0].shape[1]]
return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
else:
# freqs_cis: values in complex space
if head_first:
assert freqs_cis.shape == (
x.shape[-2],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
else:
assert freqs_cis.shape == (
x.shape[1],
x.shape[-1],
), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def rotate_half(x):
x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
head_first: bool = False,
start_offset: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Apply rotary embeddings to input tensors using the given frequency tensor.
This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
returned as real tensors.
Args:
xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
head_first (bool): head dimension first (except batch dim) or not.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
"""
# print(freqs_cis[0].shape, xq.shape, xk.shape)
xk_out = None
assert isinstance(freqs_cis, tuple)
if isinstance(freqs_cis, tuple):
cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
cos, sin = cos.to(xq.device), sin.to(xq.device)
# real * cos - imag * sin
# imag * cos + real * sin
xq_out = (xq.float() * cos[:, start_offset : start_offset + xq.shape[1], :, :] + rotate_half(xq.float()) * sin[:, start_offset : start_offset + xq.shape[1], :, :]).type_as(xq)
xk_out = (xk.float() * cos[:, start_offset : start_offset + xk.shape[1], :, :] + rotate_half(xk.float()) * sin[:, start_offset : start_offset + xk.shape[1], :, :]).type_as(xk)
else:
# view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2]
freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2]
# (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
# view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2]
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
return xq_out, xk_out
def get_nd_rotary_pos_embed(
rope_dim_list,
start,
*args,
theta=10000.0,
use_real=False,
theta_rescale_factor: Union[float, List[float]] = 1.0,
interpolation_factor: Union[float, List[float]] = 1.0,
):
"""
This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
Args:
rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
sum(rope_dim_list) should equal to head_dim of attention layer.
start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
*args: See above.
theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
part and an imaginary part separately.
theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
Returns:
pos_embed (torch.Tensor): [HW, D/2]
"""
grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list)) # [3, W, H, D] / [2, W, H]
if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
interpolation_factor = [interpolation_factor] * len(rope_dim_list)
elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)"
# use 1/ndim of dimensions to encode grid_axis
embs = []
for i in range(len(rope_dim_list)):
emb = get_1d_rotary_pos_embed(
rope_dim_list[i],
grid[i].reshape(-1),
theta,
use_real=use_real,
theta_rescale_factor=theta_rescale_factor[i],
interpolation_factor=interpolation_factor[i],
) # 2 x [WHD, rope_dim_list[i]]
embs.append(emb)
if use_real:
cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
return cos, sin
else:
emb = torch.cat(embs, dim=1) # (WHD, D/2)
return emb
def get_1d_rotary_pos_embed(
dim: int,
pos: Union[torch.FloatTensor, int],
theta: float = 10000.0,
use_real: bool = False,
theta_rescale_factor: float = 1.0,
interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
"""
Precompute the frequency tensor for complex exponential (cis) with given dimensions.
(Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
and the end index 'end'. The 'theta' parameter scales the frequencies.
The returned tensor contains complex values in complex64 data type.
Args:
dim (int): Dimension of the frequency tensor.
pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
use_real (bool, optional): If True, return real part and imaginary part separately.
Otherwise, return complex numbers.
theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
Returns:
freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
"""
if isinstance(pos, int):
pos = torch.arange(pos, device=torch.cuda.current_device()).float()
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
if theta_rescale_factor != 1.0:
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=torch.cuda.current_device())[: (dim // 2)].float() / dim)) # [D/2]
# assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
if use_real:
freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
return freqs_cos, freqs_sin
else:
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2]
return freqs_cis
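A minimal sketch of the RoPE helpers above, using the rope_dim_list from the configs ([8, 28, 28], which sums to the 64-dim attention head) and a 4 x 22 x 40 latent grid (22 * 40 = 880 matches frame_seq_length). Note that get_meshgrid_nd and get_1d_rotary_pos_embed allocate on torch.cuda.current_device(), so this assumes a CUDA machine:

import torch

# Separate cos/sin tables for a (T, H, W) grid of latent positions.
cos, sin = get_nd_rotary_pos_embed([8, 28, 28], (4, 22, 40), theta=256, use_real=True)
print(cos.shape)  # torch.Size([3520, 64]), i.e. [T*H*W, head_dim]

# Apply to dummy query/key tensors shaped [B, S, H, D].
q = torch.randn(1, 4 * 22 * 40, 16, 64, device="cuda")
k = torch.randn_like(q)
q_rot, k_rot = apply_rotary_emb(q, k, (cos, sin), head_first=False)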
import torch
from lightx2v.models.networks.wan.infer.module_io import GridOutput, WanPreInferModuleOutput
from lightx2v.models.networks.wan.infer.self_forcing.pre_infer import WanSFPreInfer, sinusoidal_embedding_1d
from lightx2v.utils.envs import *
def cond_current(conditional_dict, current_start_frame, num_frame_per_block, replace=None, mode="universal"):
new_cond = {}
new_cond["cond_concat"] = conditional_dict["image_encoder_output"]["cond_concat"][:, :, current_start_frame : current_start_frame + num_frame_per_block]
new_cond["visual_context"] = conditional_dict["image_encoder_output"]["visual_context"]
if replace:
if current_start_frame == 0:
last_frame_num = 1 + 4 * (num_frame_per_block - 1)
else:
last_frame_num = 4 * num_frame_per_block
final_frame = 1 + 4 * (current_start_frame + num_frame_per_block - 1)
if mode != "templerun":
conditional_dict["text_encoder_output"]["mouse_cond"][:, -last_frame_num + final_frame : final_frame] = replace["mouse"][None, None, :].repeat(1, last_frame_num, 1)
conditional_dict["text_encoder_output"]["keyboard_cond"][:, -last_frame_num + final_frame : final_frame] = replace["keyboard"][None, None, :].repeat(1, last_frame_num, 1)
if mode != "templerun":
new_cond["mouse_cond"] = conditional_dict["text_encoder_output"]["mouse_cond"][:, : 1 + 4 * (current_start_frame + num_frame_per_block - 1)]
new_cond["keyboard_cond"] = conditional_dict["text_encoder_output"]["keyboard_cond"][:, : 1 + 4 * (current_start_frame + num_frame_per_block - 1)]
if replace:
return new_cond, conditional_dict
else:
return new_cond
# @amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
assert dim % 2 == 0
freqs = torch.outer(torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
freqs = torch.polar(torch.ones_like(freqs), freqs)
return freqs
class WanMtxg2PreInfer(WanSFPreInfer):
def __init__(self, config):
super().__init__(config)
d = config["dim"] // config["num_heads"]
self.freqs = torch.cat([rope_params(1024, d - 4 * (d // 6)), rope_params(1024, 2 * (d // 6)), rope_params(1024, 2 * (d // 6))], dim=1).to(torch.device("cuda"))
self.dim = config["dim"]
def img_emb(self, weights, x):
x = weights.img_emb_0.apply(x)
x = weights.img_emb_1.apply(x.squeeze(0))
x = torch.nn.functional.gelu(x, approximate="none")
x = weights.img_emb_3.apply(x)
x = weights.img_emb_4.apply(x)
x = x.unsqueeze(0)
return x
@torch.no_grad()
def infer(self, weights, inputs, kv_start=0, kv_end=0):
x = self.scheduler.latents_input
t = self.scheduler.timestep_input
current_start_frame = self.scheduler.seg_index * self.scheduler.num_frame_per_block
if self.config["streaming"]:
current_actions = inputs["current_actions"]
current_conditional_dict, _ = cond_current(inputs, current_start_frame, self.scheduler.num_frame_per_block, replace=current_actions, mode=self.config["mode"])
else:
current_conditional_dict = cond_current(inputs, current_start_frame, self.scheduler.num_frame_per_block, mode=self.config["mode"])
cond_concat = current_conditional_dict["cond_concat"]
visual_context = current_conditional_dict["visual_context"]
x = torch.cat([x.unsqueeze(0), cond_concat], dim=1)
# embeddings
x = weights.patch_embedding.apply(x)
grid_sizes_t, grid_sizes_h, grid_sizes_w = torch.tensor(x.shape[2:], dtype=torch.long)
grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w))
x = x.flatten(2).transpose(1, 2) # B FHW C'
seq_lens = torch.tensor([u.size(0) for u in x], dtype=torch.long, device=torch.device("cuda"))
assert seq_lens[0] <= 15 * 1 * 880
embed_tmp = sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x) # torch.Size([3, 256])
embed = self.time_embedding(weights, embed_tmp) # torch.Size([3, 1536])
embed0 = self.time_projection(weights, embed).unflatten(dim=0, sizes=t.shape)
# context
context_lens = None
context = self.img_emb(weights, visual_context)
return WanPreInferModuleOutput(
embed=embed,
grid_sizes=grid_sizes,
x=x.squeeze(0),
embed0=embed0.squeeze(0),
seq_lens=seq_lens,
freqs=self.freqs,
context=context[0],
conditional_dict=current_conditional_dict,
)
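For orientation on the indexing used in cond_current above: with a VAE temporal compression of 4, latent frame 0 covers a single raw frame and every later latent frame covers four, so after F latent frames the conditions span 1 + 4 * (F - 1) raw frames, which is exactly the slice bound computed for keyboard_cond/mouse_cond. A small standalone check (plain Python, not part of the file):

num_frame_per_block = 3  # as in sf_config
for seg_index in range(3):
    current_start_frame = seg_index * num_frame_per_block
    raw_frames_seen = 1 + 4 * (current_start_frame + num_frame_per_block - 1)
    print(seg_index, raw_frames_seen)  # 0 -> 9, 1 -> 21, 2 -> 33 raw action frames consumed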
import math
import torch
from einops import rearrange
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ImportError:
from flash_attn import flash_attn_func
FLASH_ATTN_3_AVAILABLE = False
from lightx2v.models.networks.wan.infer.matrix_game2.posemb_layers import apply_rotary_emb, get_nd_rotary_pos_embed
from lightx2v.models.networks.wan.infer.self_forcing.transformer_infer import WanSFTransformerInfer, causal_rope_apply
class WanMtxg2TransformerInfer(WanSFTransformerInfer):
def __init__(self, config):
super().__init__(config)
self._initialize_kv_cache_mouse_and_keyboard(self.device, self.dtype)
self.sink_size = 0
self.vae_time_compression_ratio = config["action_config"]["vae_time_compression_ratio"]
self.windows_size = config["action_config"]["windows_size"]
self.patch_size = config["action_config"]["patch_size"]
self.rope_theta = config["action_config"]["rope_theta"]
self.enable_keyboard = config["action_config"]["enable_keyboard"]
self.heads_num = config["action_config"]["heads_num"]
self.hidden_size = config["action_config"]["hidden_size"]
self.img_hidden_size = config["action_config"]["img_hidden_size"]
self.keyboard_dim_in = config["action_config"]["keyboard_dim_in"]
self.keyboard_hidden_dim = config["action_config"]["keyboard_hidden_dim"]
self.qk_norm = config["action_config"]["qk_norm"]
self.qkv_bias = config["action_config"]["qkv_bias"]
self.rope_dim_list = config["action_config"]["rope_dim_list"]
self.freqs_cos, self.freqs_sin = self.get_rotary_pos_embed(7500, self.patch_size[1], self.patch_size[2], 64, self.rope_dim_list, start_offset=0)
self.enable_mouse = config["action_config"]["enable_mouse"]
if self.enable_mouse:
self.mouse_dim_in = config["action_config"]["mouse_dim_in"]
self.mouse_hidden_dim = config["action_config"]["mouse_hidden_dim"]
self.mouse_qk_dim_list = config["action_config"]["mouse_qk_dim_list"]
def get_rotary_pos_embed(self, video_length, height, width, head_dim, rope_dim_list=None, start_offset=0):
target_ndim = 3
ndim = 5 - 2
latents_size = [video_length + start_offset, height, width]
if isinstance(self.patch_size, int):
assert all(s % self.patch_size == 0 for s in latents_size), f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.patch_size}), but got {latents_size}."
rope_sizes = [s // self.patch_size for s in latents_size]
elif isinstance(self.patch_size, list):
assert all(s % self.patch_size[idx] == 0 for idx, s in enumerate(latents_size)), (
f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.patch_size}), but got {latents_size}."
)
rope_sizes = [s // self.patch_size[idx] for idx, s in enumerate(latents_size)]
if len(rope_sizes) != target_ndim:
rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
if rope_dim_list is None:
rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
assert sum(rope_dim_list) == head_dim, "sum(rope_dim_list) should equal to head_dim of attention layer"
freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
rope_dim_list,
rope_sizes,
theta=self.rope_theta,
use_real=True,
theta_rescale_factor=1,
)
return freqs_cos[-video_length * rope_sizes[1] * rope_sizes[2] // self.patch_size[0] :], freqs_sin[-video_length * rope_sizes[1] * rope_sizes[2] // self.patch_size[0] :]
def _initialize_kv_cache(self, dtype, device):
"""
Initialize a Per-GPU KV cache for the Wan model.
"""
kv_cache1 = []
if self.local_attn_size != -1:
# Use the local attention size to compute the KV cache size
kv_cache_size = self.local_attn_size * self.frame_seq_length
else:
# Use the default KV cache size
kv_cache_size = 32760
for _ in range(self.num_transformer_blocks):
kv_cache1.append(
{
"k": torch.zeros((kv_cache_size, 12, 128)).to(dtype).to(device),
"v": torch.zeros((kv_cache_size, 12, 128)).to(dtype).to(device),
"global_end_index": 0,
"local_end_index": 0,
}
)
self.kv_cache1_default = kv_cache1
def _initialize_kv_cache_mouse_and_keyboard(self, device, dtype):
"""
Initialize a Per-GPU KV cache for the Wan model.
"""
kv_cache_mouse = []
kv_cache_keyboard = []
if self.local_attn_size != -1:
kv_cache_size = self.local_attn_size
else:
kv_cache_size = 15 * 1
for _ in range(self.num_transformer_blocks):
kv_cache_keyboard.append(
{
"k": torch.zeros([1, kv_cache_size, 16, 64], dtype=dtype, device=device),
"v": torch.zeros([1, kv_cache_size, 16, 64], dtype=dtype, device=device),
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
"local_end_index": torch.tensor([0], dtype=torch.long, device=device),
}
)
kv_cache_mouse.append(
{
"k": torch.zeros([self.frame_seq_length, kv_cache_size, 16, 64], dtype=dtype, device=device),
"v": torch.zeros([self.frame_seq_length, kv_cache_size, 16, 64], dtype=dtype, device=device),
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
"local_end_index": torch.tensor([0], dtype=torch.long, device=device),
}
)
self.kv_cache_keyboard = kv_cache_keyboard
self.kv_cache_mouse = kv_cache_mouse
def infer_self_attn_with_kvcache(self, phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
if hasattr(phase, "smooth_norm1_weight"):
norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor
norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor
else:
norm1_weight = 1 + scale_msa.squeeze()
norm1_bias = shift_msa.squeeze()
norm1_out = phase.norm1.apply(x)
if self.sensitive_layer_dtype != self.infer_dtype:
norm1_out = norm1_out.to(self.sensitive_layer_dtype)
norm1_out.mul_(norm1_weight[0:1, :]).add_(norm1_bias[0:1, :])
if self.sensitive_layer_dtype != self.infer_dtype: # False
norm1_out = norm1_out.to(self.infer_dtype)
s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
q0 = phase.self_attn_q.apply(norm1_out)
k0 = phase.self_attn_k.apply(norm1_out)
q = phase.self_attn_norm_q.apply(q0).view(s, n, d)
k = phase.self_attn_norm_k.apply(k0).view(s, n, d)
v = phase.self_attn_v.apply(norm1_out).view(s, n, d)
seg_index = self.scheduler.seg_index
frame_seqlen = math.prod(grid_sizes[0][1:]).item()
current_start = seg_index * self.num_frame_per_block * self.frame_seq_length
current_start_frame = current_start // frame_seqlen
q = causal_rope_apply(q.unsqueeze(0), grid_sizes, freqs, start_frame=current_start_frame).type_as(v)[0]
k = causal_rope_apply(k.unsqueeze(0), grid_sizes, freqs, start_frame=current_start_frame).type_as(v)[0]
current_end = current_start + q.shape[0]
sink_tokens = self.sink_size * frame_seqlen
kv_cache_size = self.kv_cache1[self.block_idx]["k"].shape[0]
num_new_tokens = q.shape[0]
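# Rolling-window eviction: if appending the new tokens would overflow the fixed-size cache,
# shift the retained (non-sink) entries left by the number of evicted tokens and write the
# new k/v at the tail; otherwise the new k/v are appended directly at local_end_index.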
if (current_end > self.kv_cache1[self.block_idx]["global_end_index"]) and (num_new_tokens + self.kv_cache1[self.block_idx]["local_end_index"] > kv_cache_size):
num_evicted_tokens = num_new_tokens + self.kv_cache1[self.block_idx]["local_end_index"] - kv_cache_size
num_rolled_tokens = self.kv_cache1[self.block_idx]["local_end_index"] - num_evicted_tokens - sink_tokens
self.kv_cache1[self.block_idx]["k"][sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache1[self.block_idx]["k"][
sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens
].clone()
self.kv_cache1[self.block_idx]["v"][sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache1[self.block_idx]["v"][
sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens
].clone()
# Insert the new keys/values at the end
local_end_index = self.kv_cache1[self.block_idx]["local_end_index"] + current_end - self.kv_cache1[self.block_idx]["global_end_index"] - num_evicted_tokens
local_start_index = local_end_index - num_new_tokens
self.kv_cache1[self.block_idx]["k"][local_start_index:local_end_index] = k
self.kv_cache1[self.block_idx]["v"][local_start_index:local_end_index] = v
else:
# Assign new keys/values directly up to current_end
local_end_index = self.kv_cache1[self.block_idx]["local_end_index"] + current_end - self.kv_cache1[self.block_idx]["global_end_index"]
local_start_index = local_end_index - num_new_tokens
self.kv_cache1[self.block_idx]["k"][local_start_index:local_end_index] = k
self.kv_cache1[self.block_idx]["v"][local_start_index:local_end_index] = v
attn_k = self.kv_cache1[self.block_idx]["k"][max(0, local_end_index - self.max_attention_size) : local_end_index]
attn_v = self.kv_cache1[self.block_idx]["v"][max(0, local_end_index - self.max_attention_size) : local_end_index]
self.kv_cache1[self.block_idx]["local_end_index"] = local_end_index
self.kv_cache1[self.block_idx]["global_end_index"] = current_end
k_lens = torch.empty_like(seq_lens).fill_(attn_k.size(0))
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens)
if self.clean_cuda_cache:
del norm1_out, norm1_weight, norm1_bias
torch.cuda.empty_cache()
if self.config["seq_parallel"]:
attn_out = phase.self_attn_1_parallel.apply(
q=q,
k=attn_k,
v=attn_v,
img_qkv_len=q.shape[0],
cu_seqlens_qkv=cu_seqlens_q,
attention_module=phase.self_attn_1,
seq_p_group=self.seq_p_group,
model_cls=self.config["model_cls"],
)
else:
attn_out = phase.self_attn_1.apply(
q=q,
k=attn_k,
v=attn_v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=q.size(0),
max_seqlen_kv=attn_k.size(0),
model_cls=self.config["model_cls"],
)
y = phase.self_attn_o.apply(attn_out)
if self.clean_cuda_cache:
del q, k, v, attn_out
torch.cuda.empty_cache()
return y
def infer_cross_attn_with_kvcache(self, phase, x, context, y_out, gate_msa):
num_frames = gate_msa.shape[0]
frame_seqlen = x.shape[0] // gate_msa.shape[0]
x.add_((y_out.unflatten(dim=0, sizes=(num_frames, frame_seqlen)) * gate_msa).flatten(0, 1))
norm3_out = phase.norm3.apply(x)
n, d = self.num_heads, self.head_dim
q = phase.cross_attn_q.apply(norm3_out)
q = phase.cross_attn_norm_q.apply(q).view(-1, n, d)
if not self.crossattn_cache[self.block_idx]["is_init"]:
self.crossattn_cache[self.block_idx]["is_init"] = True
k = phase.cross_attn_k.apply(context)
k = phase.cross_attn_norm_k.apply(k).view(-1, n, d)
v = phase.cross_attn_v.apply(context)
v = v.view(-1, n, d)
self.crossattn_cache[self.block_idx]["k"] = k
self.crossattn_cache[self.block_idx]["v"] = v
else:
k = self.crossattn_cache[self.block_idx]["k"]
v = self.crossattn_cache[self.block_idx]["v"]
cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(
q,
k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device),
)
attn_out = phase.cross_attn_1.apply(
q=q,
k=k,
v=v,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_kv=cu_seqlens_k,
max_seqlen_q=q.size(0),
max_seqlen_kv=k.size(0),
model_cls=self.config["model_cls"],
)
attn_out = phase.cross_attn_o.apply(attn_out)
if self.clean_cuda_cache:
del q, k, v, norm3_out, context
torch.cuda.empty_cache()
return x, attn_out
def infer_action_model(self, phase, x, grid_sizes, seq_lens, mouse_condition=None, keyboard_condition=None, is_causal=False, use_rope_keyboard=True):
tt, th, tw = grid_sizes
current_start = self.scheduler.seg_index * self.num_frame_per_block
start_frame = current_start
B, N_frames, C = keyboard_condition.shape
assert tt * th * tw == x.shape[0]
assert ((N_frames - 1) + self.vae_time_compression_ratio) % self.vae_time_compression_ratio == 0
N_feats = int((N_frames - 1) / self.vae_time_compression_ratio) + 1
# Defined freqs_cis early so it's available for both mouse and keyboard
freqs_cis = (self.freqs_cos, self.freqs_sin)
cond1 = N_feats == tt
cond2 = is_causal and not self.kv_cache_mouse
cond3 = (N_frames - 1) // self.vae_time_compression_ratio + 1 == current_start + self.num_frame_per_block
assert (cond1 and ((cond2) or not is_causal)) or (cond3 and is_causal)
x = x.unsqueeze(0)
if self.enable_mouse and mouse_condition is not None:
hidden_states = rearrange(x, "B (T S) C -> (B S) T C", T=tt, S=th * tw) # 65*272*480 -> 17*(272//16)*(480//16) -> 8670
B, N_frames, C = mouse_condition.shape
else:
hidden_states = x
pad_t = self.vae_time_compression_ratio * self.windows_size
if self.enable_mouse and mouse_condition is not None:
pad = mouse_condition[:, 0:1, :].expand(-1, pad_t, -1)
mouse_condition = torch.cat([pad, mouse_condition], dim=1)
if is_causal and self.kv_cache_mouse is not None:
mouse_condition = mouse_condition[:, self.vae_time_compression_ratio * (N_feats - self.num_frame_per_block - self.windows_size) + pad_t :, :]
group_mouse = [
mouse_condition[:, self.vae_time_compression_ratio * (i - self.windows_size) + pad_t : i * self.vae_time_compression_ratio + pad_t, :] for i in range(self.num_frame_per_block)
]
else:
group_mouse = [mouse_condition[:, self.vae_time_compression_ratio * (i - self.windows_size) + pad_t : i * self.vae_time_compression_ratio + pad_t, :] for i in range(N_feats)]
group_mouse = torch.stack(group_mouse, dim=1)
S = th * tw
group_mouse = group_mouse.unsqueeze(-1).expand(B, self.num_frame_per_block, pad_t, C, S)
group_mouse = group_mouse.permute(0, 4, 1, 2, 3).reshape(B * S, self.num_frame_per_block, pad_t * C)
group_mouse = torch.cat([hidden_states, group_mouse], dim=-1)
# mouse_mlp
# Note: the extra batch dimension here is unavoidable, so torch.nn.functional is used directly
group_mouse = torch.nn.functional.linear(group_mouse, phase.mouse_mlp_0.weight.T, phase.mouse_mlp_0.bias)
group_mouse = torch.nn.functional.gelu(group_mouse, approximate="tanh")
group_mouse = torch.nn.functional.linear(group_mouse, phase.mouse_mlp_2.weight.T, phase.mouse_mlp_2.bias)
group_mouse = torch.nn.functional.layer_norm(group_mouse, (group_mouse.shape[-1],), phase.mouse_mlp_3.weight.T, phase.mouse_mlp_3.bias, 1e-5)
# qkvc
mouse_qkv = torch.nn.functional.linear(group_mouse, phase.t_qkv.weight.T)
q0, k0, v = rearrange(mouse_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)  # each of q0 / k0 / v is (B*S, T, H, D), e.g. torch.Size([880, 3, 16, 64])
q = q0 * torch.rsqrt(q0.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
k = k0 * torch.rsqrt(k0.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
q, k = apply_rotary_emb(q, k, freqs_cis, start_offset=start_frame, head_first=False)
# KV cache handling for the streaming (causal) path
if is_causal:
current_end = current_start + q.shape[1]
assert q.shape[1] == self.num_frame_per_block
sink_size = 0
max_attention_size = self.local_attn_size
sink_tokens = sink_size * 1
kv_cache_size = self.kv_cache_mouse[self.block_idx]["k"].shape[1]
num_new_tokens = q.shape[1]
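# Rolling KV cache: if appending the new tokens would overflow the fixed-size cache,
# evict the oldest non-sink entries by shifting the remainder left, then write the new
# K/V at the end. Attention only reads back at most `max_attention_size` cached frames.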
if (current_end > self.kv_cache_mouse[self.block_idx]["global_end_index"].item()) and (num_new_tokens + self.kv_cache_mouse[self.block_idx]["local_end_index"].item() > kv_cache_size):
num_evicted_tokens = num_new_tokens + self.kv_cache_mouse[self.block_idx]["local_end_index"].item() - kv_cache_size
num_rolled_tokens = self.kv_cache_mouse[self.block_idx]["local_end_index"].item() - num_evicted_tokens - sink_tokens
self.kv_cache_mouse[self.block_idx]["k"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_mouse[self.block_idx]["k"][
:, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens
].clone()
self.kv_cache_mouse[self.block_idx]["v"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_mouse[self.block_idx]["v"][
:, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens
].clone()
# Insert the new keys/values at the end
local_end_index = self.kv_cache_mouse[self.block_idx]["local_end_index"].item() + current_end - self.kv_cache_mouse[self.block_idx]["global_end_index"].item() - num_evicted_tokens
local_start_index = local_end_index - num_new_tokens
else:
local_end_index = self.kv_cache_mouse[self.block_idx]["local_end_index"].item() + current_end - self.kv_cache_mouse[self.block_idx]["global_end_index"].item()
local_start_index = local_end_index - num_new_tokens
self.kv_cache_mouse[self.block_idx]["k"][:, local_start_index:local_end_index] = k
self.kv_cache_mouse[self.block_idx]["v"][:, local_start_index:local_end_index] = v
attn_k = self.kv_cache_mouse[self.block_idx]["k"][:, max(0, local_end_index - max_attention_size) : local_end_index]
attn_v = self.kv_cache_mouse[self.block_idx]["v"][:, max(0, local_end_index - max_attention_size) : local_end_index]
attn = flash_attn_interface.flash_attn_func(
q,
attn_k,
attn_v,
)
self.kv_cache_mouse[self.block_idx]["global_end_index"].fill_(current_end)
self.kv_cache_mouse[self.block_idx]["local_end_index"].fill_(local_end_index)
else:
attn = flash_attn_func(
q,
k,
v,
)
# merge heads and restore the (T S) token layout before the output projection
attn = rearrange(attn, "(b S) T h d -> b (T S) (h d)", b=B)
hidden_states = rearrange(x, "(B S) T C -> B (T S) C", B=B)
attn = phase.proj_mouse.apply(attn[0]).unsqueeze(0)
hidden_states = hidden_states + attn
if self.enable_keyboard and keyboard_condition is not None:
pad = keyboard_condition[:, 0:1, :].expand(-1, pad_t, -1)
keyboard_condition = torch.cat([pad, keyboard_condition], dim=1)
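# Same replication padding as the mouse branch; in streaming mode only the slice of
# keyboard frames covering the current block (plus its look-back window) is kept.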
if is_causal and self.kv_cache_keyboard is not None:
keyboard_condition = keyboard_condition[
:, self.vae_time_compression_ratio * (N_feats - self.num_frame_per_block - self.windows_size) + pad_t :, :
] # keyboard_condition[:, self.vae_time_compression_ratio*(start_frame - self.windows_size) + pad_t:start_frame * self.vae_time_compression_ratio + pad_t,:]
keyboard_condition = phase.keyboard_embed_0.apply(keyboard_condition[0])
keyboard_condition = torch.nn.functional.silu(keyboard_condition)
keyboard_condition = phase.keyboard_embed_2.apply(keyboard_condition).unsqueeze(0)
group_keyboard = [
keyboard_condition[:, self.vae_time_compression_ratio * (i - self.windows_size) + pad_t : i * self.vae_time_compression_ratio + pad_t, :] for i in range(self.num_frame_per_block)
]
else:
keyboard_condition = phase.keyboard_embed_0.apply(keyboard_condition[0])
keyboard_condition = torch.nn.functional.silu(keyboard_condition)
keyboard_condition = phase.keyboard_embed_2.apply(keyboard_condition).unsqueeze(0)
group_keyboard = [keyboard_condition[:, self.vae_time_compression_ratio * (i - self.windows_size) + pad_t : i * self.vae_time_compression_ratio + pad_t, :] for i in range(N_feats)]
group_keyboard = torch.stack(group_keyboard, dim=1) # B F RW C
group_keyboard = group_keyboard.reshape(shape=(group_keyboard.shape[0], group_keyboard.shape[1], -1))
# apply cross attn
mouse_q = phase.mouse_attn_q.apply(hidden_states[0]).unsqueeze(0)
keyboard_kv = phase.keyboard_attn_kv.apply(group_keyboard[0]).unsqueeze(0)
B, L, HD = mouse_q.shape
D = HD // self.heads_num
q = mouse_q.view(B, L, self.heads_num, D)
B, L, KHD = keyboard_kv.shape
k, v = keyboard_kv.view(B, L, 2, self.heads_num, D).permute(2, 0, 1, 3, 4)
# RMS-style qk norm on the mouse-query / keyboard-key cross-attention
q = q * torch.rsqrt(q.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
k = k * torch.rsqrt(k.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
S = th * tw
assert S == 880
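# 880 is the per-frame spatial token count (frame_seq_length) these caches were sized
# for; together with the batch-size-1 assumption below it fixes the cache layout.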
# position embed
if use_rope_keyboard:
B, TS, H, D = q.shape
T_ = TS // S
q = q.view(B, T_, S, H, D).transpose(1, 2).reshape(B * S, T_, H, D)
q, k = apply_rotary_emb(q, k, freqs_cis, start_offset=start_frame, head_first=False)
k1, k2, k3, k4 = k.shape
k = k.expand(S, k2, k3, k4)
v = v.expand(S, k2, k3, k4)
if is_causal:
current_end = current_start + k.shape[1]
assert k.shape[1] == self.num_frame_per_block
sink_size = 0
max_attention_size = self.local_attn_size
sink_tokens = sink_size * 1
kv_cache_size = self.kv_cache_keyboard[self.block_idx]["k"].shape[1]
num_new_tokens = k.shape[1]
if (current_end > self.kv_cache_keyboard[self.block_idx]["global_end_index"].item()) and (
num_new_tokens + self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() > kv_cache_size
):
num_evicted_tokens = num_new_tokens + self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() - kv_cache_size
num_rolled_tokens = self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() - num_evicted_tokens - sink_tokens
self.kv_cache_keyboard[self.block_idx]["k"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_keyboard[self.block_idx]["k"][
:, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens
].clone()
self.kv_cache_keyboard[self.block_idx]["v"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_keyboard[self.block_idx]["v"][
:, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens
].clone()
# Insert the new keys/values at the end
local_end_index = (
self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() + current_end - self.kv_cache_keyboard[self.block_idx]["global_end_index"].item() - num_evicted_tokens
)
local_start_index = local_end_index - num_new_tokens
else:
local_end_index = self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() + current_end - self.kv_cache_keyboard[self.block_idx]["global_end_index"].item()
local_start_index = local_end_index - num_new_tokens
assert k.shape[0] == 880  # assumes batch size 1; otherwise the cache save/load logic would need to change
self.kv_cache_keyboard[self.block_idx]["k"][:, local_start_index:local_end_index] = k[:1]
self.kv_cache_keyboard[self.block_idx]["v"][:, local_start_index:local_end_index] = v[:1]
if FLASH_ATTN_3_AVAILABLE:
attn_k = self.kv_cache_keyboard[self.block_idx]["k"][:, max(0, local_end_index - max_attention_size) : local_end_index].repeat(S, 1, 1, 1)
attn_v = self.kv_cache_keyboard[self.block_idx]["v"][:, max(0, local_end_index - max_attention_size) : local_end_index].repeat(S, 1, 1, 1)
attn = flash_attn_interface.flash_attn_func(
q,
attn_k,
attn_v,
)
else:
attn = flash_attn_func(
q,
self.kv_cache_keyboard[self.block_idx]["k"][max(0, local_end_index - max_attention_size) : local_end_index].repeat(S, 1, 1, 1),
self.kv_cache_keyboard[self.block_idx]["v"][max(0, local_end_index - max_attention_size) : local_end_index].repeat(S, 1, 1, 1),
)
self.kv_cache_keyboard[self.block_idx]["global_end_index"].fill_(current_end)
self.kv_cache_keyboard[self.block_idx]["local_end_index"].fill_(local_end_index)
else:
attn = flash_attn_func(
q,
k,
v,
causal=False,
)
attn = rearrange(attn, "(B S) T H D -> B (T S) (H D)", S=S)
else:
if is_causal:
current_start = start_frame
current_end = current_start + k.shape[1]
assert k.shape[1] == self.num_frame_per_block
sink_size = 0
local_attn_size = self.local_attn_size
max_attention_size = self.local_attn_size
sink_tokens = sink_size * 1
kv_cache_size = self.kv_cache_keyboard[self.block_idx]["k"].shape[1]
num_new_tokens = k.shape[1]
if (current_end > self.kv_cache_keyboard[self.block_idx]["global_end_index"].item()) and (
num_new_tokens + self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() > kv_cache_size
):
num_evicted_tokens = num_new_tokens + self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() - kv_cache_size
num_rolled_tokens = self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() - num_evicted_tokens - sink_tokens
self.kv_cache_keyboard[self.block_idx]["k"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_keyboard[self.block_idx]["k"][
:, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens
].clone()
self.kv_cache_keyboard[self.block_idx]["v"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_keyboard[self.block_idx]["v"][
:, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens
].clone()
# Insert the new keys/values at the end
local_end_index = (
self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() + current_end - self.kv_cache_keyboard[self.block_idx]["global_end_index"].item() - num_evicted_tokens
)
local_start_index = local_end_index - num_new_tokens
else:
local_end_index = self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() + current_end - self.kv_cache_keyboard[self.block_idx]["global_end_index"].item()
local_start_index = local_end_index - num_new_tokens
self.kv_cache_keyboard[self.block_idx]["k"][:, local_start_index:local_end_index] = k
self.kv_cache_keyboard[self.block_idx]["v"][:, local_start_index:local_end_index] = v
attn = flash_attn_func(
q,
self.kv_cache_keyboard[self.block_idx]["k"][:, max(0, local_end_index - max_attention_size) : local_end_index],
self.kv_cache_keyboard[self.block_idx]["v"][:, max(0, local_end_index - max_attention_size) : local_end_index],
)
self.kv_cache_keyboard[self.block_idx]["global_end_index"].fill_(current_end)
self.kv_cache_keyboard[self.block_idx]["local_end_index"].fill_(local_end_index)
else:
attn = flash_attn_func(
q,
k,
v,
)
attn = rearrange(attn, "B L H D -> B L (H D)")
attn = phase.proj_keyboard.apply(attn[0]).unsqueeze(0)
hidden_states = hidden_states + attn
hidden_states = hidden_states.squeeze(0)
return hidden_states
def infer_ffn(self, phase, x, c_shift_msa, c_scale_msa):
num_frames = c_shift_msa.shape[0]
frame_seqlen = x.shape[0] // c_shift_msa.shape[0]
x = phase.norm2.apply(x).unsqueeze(0)
x = x.unflatten(dim=1, sizes=(num_frames, frame_seqlen))
c_scale_msa = c_scale_msa.unsqueeze(0)
c_shift_msa = c_shift_msa.unsqueeze(0)
x = x * (1 + c_scale_msa) + c_shift_msa
x = x.flatten(1, 2).squeeze(0)
y = phase.ffn_0.apply(x)
y = torch.nn.functional.gelu(y, approximate="tanh")
y = phase.ffn_2.apply(y)
return y
def post_process(self, x, y, c_gate_msa, pre_infer_out=None):
x = x + y * c_gate_msa[0]
x = x.squeeze(0)
return x
def infer_block_witch_kvcache(self, block, x, pre_infer_out):
if hasattr(block.compute_phases[0], "before_proj"):
x = block.compute_phases[0].before_proj.apply(x) + pre_infer_out.x
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.pre_process(
block.compute_phases[0].modulation,
pre_infer_out.embed0,
)
y_out = self.infer_self_attn_with_kvcache(
block.compute_phases[0],
pre_infer_out.grid_sizes.tensor,
x,
pre_infer_out.seq_lens,
pre_infer_out.freqs,
shift_msa,
scale_msa,
)
x, attn_out = self.infer_cross_attn_with_kvcache(
block.compute_phases[1],
x,
pre_infer_out.context,
y_out,
gate_msa,
)
x = x + attn_out
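# Blocks that carry an action phase expose 4 compute phases (self-attn, cross-attn,
# action module, FFN); templerun mode is keyboard-only, so no mouse condition is passed.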
if len(block.compute_phases) == 4:
if self.config["mode"] != "templerun":
x = self.infer_action_model(
phase=block.compute_phases[2],
x=x,
grid_sizes=pre_infer_out.grid_sizes.tensor[0],
seq_lens=pre_infer_out.seq_lens,
mouse_condition=pre_infer_out.conditional_dict["mouse_cond"],
keyboard_condition=pre_infer_out.conditional_dict["keyboard_cond"],
is_causal=True,
use_rope_keyboard=True,
)
else:
x = self.infer_action_model(
phase=block.compute_phases[2],
x=x,
grid_sizes=pre_infer_out.grid_sizes.tensor[0],
seq_lens=pre_infer_out.seq_lens,
keyboard_condition=pre_infer_out.conditional_dict["keyboard_cond"],
is_causal=True,
use_rope_keyboard=True,
)
y = self.infer_ffn(block.compute_phases[3], x, c_shift_msa, c_scale_msa)
elif len(block.compute_phases) == 3:
y = self.infer_ffn(block.compute_phases[2], x, c_shift_msa, c_scale_msa)
x = self.post_process(x, y, c_gate_msa, pre_infer_out)
return x
def infer_non_blocks(self, weights, x, e):
num_frames = e.shape[0]
frame_seqlen = x.shape[0] // e.shape[0]
e = e.unsqueeze(0).unsqueeze(2)
x = weights.norm.apply(x).unsqueeze(0)
x = x.unflatten(dim=1, sizes=(num_frames, frame_seqlen))
modulation = weights.head_modulation.tensor
e = (modulation.unsqueeze(1) + e).chunk(2, dim=2)
x = x * (1 + e[1]) + e[0]
x = torch.nn.functional.linear(x, weights.head.weight.T, weights.head.bias)
if self.clean_cuda_cache:
del e
torch.cuda.empty_cache()
return x
@@ -20,3 +20,4 @@ class WanPreInferModuleOutput:
freqs: torch.Tensor
context: torch.Tensor
adapter_args: Dict[str, Any] = field(default_factory=dict)
conditional_dict: Dict[str, Any] = field(default_factory=dict)
import json
import os
import torch
from safetensors import safe_open
from lightx2v.models.networks.wan.infer.matrix_game2.pre_infer import WanMtxg2PreInfer
from lightx2v.models.networks.wan.infer.matrix_game2.transformer_infer import WanMtxg2TransformerInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.sf_model import WanSFModel
from lightx2v.models.networks.wan.weights.matrix_game2.pre_weights import WanMtxg2PreWeights
from lightx2v.models.networks.wan.weights.matrix_game2.transformer_weights import WanActionTransformerWeights
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *
class WanSFMtxg2Model(WanSFModel):
pre_weight_class = WanMtxg2PreWeights
transformer_weight_class = WanActionTransformerWeights
def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
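# Load one safetensors file, casting weights to the global dtype; layers whose name
# matches a "sensitive" pattern keep the higher-precision sensitive dtype, unless
# unified_dtype forces a single dtype for everything.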
with safe_open(file_path, framework="pt", device=str(self.device)) as f:
return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())) for key in f.keys()}
def _load_ckpt(self, unified_dtype, sensitive_layer):
file_path = os.path.join(self.config["model_path"], f"{self.config['sub_model_folder']}/{self.config['sub_model_name']}")
_weight_dict = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
weight_dict = {}
for k, v in _weight_dict.items():
name = k[6:]  # drop the 6-character checkpoint key prefix (presumably "model.")
weight = v.to(torch.bfloat16).to(self.device)
weight_dict.update({name: weight})
del _weight_dict
return weight_dict
def _init_infer_class(self):
# update config by real model config
with open(os.path.join(self.config["model_path"], self.config["sub_model_folder"], "config.json")) as f:
model_config = json.load(f)
for k in model_config.keys():
self.config[k] = model_config[k]
self.pre_infer_class = WanMtxg2PreInfer
self.post_infer_class = WanPostInfer
self.transformer_infer_class = WanMtxg2TransformerInfer
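# Note (illustrative, not part of the repo): the per-key loop above is equivalent to
# `self.config.update(model_config)`, i.e. every field of the sub-model's config.json
# (dim, num_layers, action_config, ...) overrides the runtime config entry in place.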
@@ -11,7 +11,8 @@ from lightx2v.models.networks.wan.model import WanModel
class WanSFModel(WanModel):
def __init__(self, model_path, config, device):
super().__init__(model_path, config, device)
if config["model_cls"] not in ["wan2.1_sf_mtxg2"]:
self.to_cuda()
def _load_ckpt(self, unified_dtype, sensitive_layer):
sf_confg = self.config["sf_config"]
...
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import (
CONV3D_WEIGHT_REGISTER,
LN_WEIGHT_REGISTER,
MM_WEIGHT_REGISTER,
)
class WanMtxg2PreWeights(WeightModule):
def __init__(self, config):
super().__init__()
self.in_dim = config["in_dim"]
self.dim = config["dim"]
self.patch_size = (1, 2, 2)
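# patch_size (1, 2, 2): no temporal patching, 2x2 spatial patching of the latent grid
# (matches the "patch_size" entry in action_config)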
self.config = config
# patch
self.add_module(
"patch_embedding",
CONV3D_WEIGHT_REGISTER["Default"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size),
)
# time
self.add_module(
"time_embedding_0",
MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias"),
)
self.add_module(
"time_embedding_2",
MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias"),
)
self.add_module(
"time_projection_1",
MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"),
)
# img_emb
self.add_module(
"img_emb_0",
LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias", eps=1e-5),
)
self.add_module(
"img_emb_1",
MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias"),
)
self.add_module(
"img_emb_3",
MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias"),
)
self.add_module(
"img_emb_4",
LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias", eps=1e-5),
)