xuwx1 / LightX2V · Commits · 9826b8ca

Commit 9826b8ca (unverified) · authored Nov 13, 2025 by Watebear · committed via GitHub Nov 13, 2025
Parent: 44e215f3

[feat]: support matrix game2 universal, gta_drive, templerun & streaming mode

34 changes · showing the first 20 changed files, with 2199 additions and 5 deletions (+2199 −5)
configs/matrix_game2/matrix_game2_gta_drive.json                      +72   -0
configs/matrix_game2/matrix_game2_gta_drive_streaming.json            +72   -0
configs/matrix_game2/matrix_game2_templerun.json                      +65   -0
configs/matrix_game2/matrix_game2_templerun_streaming.json            +72   -0
configs/matrix_game2/matrix_game2_universal.json                      +72   -0
configs/matrix_game2/matrix_game2_universal_streaming.json            +72   -0
lightx2v/common/ops/mm/mm_weight.py                                   +2    -2
lightx2v/infer.py                                                     +2    -0
lightx2v/models/input_encoders/hf/q_linear.py                         +2    -2
lightx2v/models/input_encoders/hf/wan/matrix_game2/__init__.py        +0    -0
lightx2v/models/input_encoders/hf/wan/matrix_game2/clip.py            +332  -0
lightx2v/models/input_encoders/hf/wan/matrix_game2/conditions.py      +203  -0
lightx2v/models/input_encoders/hf/wan/matrix_game2/tokenizers.py      +75   -0
lightx2v/models/networks/wan/infer/matrix_game2/posemb_layers.py      +291  -0
lightx2v/models/networks/wan/infer/matrix_game2/pre_infer.py          +98   -0
lightx2v/models/networks/wan/infer/matrix_game2/transformer_infer.py  +668  -0
lightx2v/models/networks/wan/infer/module_io.py                       +1    -0
lightx2v/models/networks/wan/matrix_game2_model.py                    +48   -0
lightx2v/models/networks/wan/sf_model.py                              +2    -1
lightx2v/models/networks/wan/weights/matrix_game2/pre_weights.py      +50   -0
configs/matrix_game2/matrix_game2_gta_drive.json (new file, mode 100644)
```json
{
  "infer_steps": 50,
  "target_video_length": 150,
  "num_output_frames": 150,
  "text_len": 512,
  "target_height": 480,
  "target_width": 832,
  "self_attn_1_type": "flash_attn2",
  "cross_attn_1_type": "flash_attn2",
  "cross_attn_2_type": "flash_attn2",
  "seed": 0,
  "sample_guide_scale": 6,
  "sample_shift": 8,
  "enable_cfg": false,
  "cpu_offload": false,
  "sf_config": {
    "local_attn_size": 6,
    "shift": 5.0,
    "num_frame_per_block": 3,
    "num_transformer_blocks": 30,
    "frame_seq_length": 880,
    "num_output_frames": 150,
    "num_inference_steps": 1000,
    "denoising_step_list": [1000.0000, 908.8427, 713.9794]
  },
  "sub_model_folder": "gta_distilled_model",
  "sub_model_name": "gta_keyboard2dim.safetensors",
  "mode": "gta_drive",
  "streaming": false,
  "action_config": {
    "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
    "enable_keyboard": true,
    "enable_mouse": true,
    "heads_num": 16,
    "hidden_size": 128,
    "img_hidden_size": 1536,
    "keyboard_dim_in": 4,
    "keyboard_hidden_dim": 1024,
    "mouse_dim_in": 2,
    "mouse_hidden_dim": 1024,
    "mouse_qk_dim_list": [8, 28, 28],
    "patch_size": [1, 2, 2],
    "qk_norm": true,
    "qkv_bias": false,
    "rope_dim_list": [8, 28, 28],
    "rope_theta": 256,
    "vae_time_compression_ratio": 4,
    "windows_size": 3
  },
  "dim": 1536,
  "eps": 1e-06,
  "ffn_dim": 8960,
  "freq_dim": 256,
  "in_dim": 36,
  "inject_sample_info": false,
  "model_type": "i2v",
  "num_heads": 12,
  "num_layers": 30,
  "out_dim": 16
}
```
configs/matrix_game2/matrix_game2_gta_drive_streaming.json (new file, mode 100644)
```json
{
  "infer_steps": 50,
  "target_video_length": 360,
  "num_output_frames": 360,
  "text_len": 512,
  "target_height": 480,
  "target_width": 832,
  "self_attn_1_type": "flash_attn2",
  "cross_attn_1_type": "flash_attn2",
  "cross_attn_2_type": "flash_attn2",
  "seed": 0,
  "sample_guide_scale": 6,
  "sample_shift": 8,
  "enable_cfg": false,
  "cpu_offload": false,
  "sf_config": {
    "local_attn_size": 6,
    "shift": 5.0,
    "num_frame_per_block": 3,
    "num_transformer_blocks": 30,
    "frame_seq_length": 880,
    "num_output_frames": 360,
    "num_inference_steps": 1000,
    "denoising_step_list": [1000.0000, 908.8427, 713.9794]
  },
  "sub_model_folder": "gta_distilled_model",
  "sub_model_name": "gta_keyboard2dim.safetensors",
  "mode": "gta_drive",
  "streaming": true,
  "action_config": {
    "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
    "enable_keyboard": true,
    "enable_mouse": true,
    "heads_num": 16,
    "hidden_size": 128,
    "img_hidden_size": 1536,
    "keyboard_dim_in": 4,
    "keyboard_hidden_dim": 1024,
    "mouse_dim_in": 2,
    "mouse_hidden_dim": 1024,
    "mouse_qk_dim_list": [8, 28, 28],
    "patch_size": [1, 2, 2],
    "qk_norm": true,
    "qkv_bias": false,
    "rope_dim_list": [8, 28, 28],
    "rope_theta": 256,
    "vae_time_compression_ratio": 4,
    "windows_size": 3
  },
  "dim": 1536,
  "eps": 1e-06,
  "ffn_dim": 8960,
  "freq_dim": 256,
  "in_dim": 36,
  "inject_sample_info": false,
  "model_type": "i2v",
  "num_heads": 12,
  "num_layers": 30,
  "out_dim": 16
}
```
configs/matrix_game2/matrix_game2_templerun.json (new file, mode 100644)
```json
{
  "infer_steps": 50,
  "target_video_length": 150,
  "num_output_frames": 150,
  "text_len": 512,
  "target_height": 480,
  "target_width": 832,
  "self_attn_1_type": "flash_attn2",
  "cross_attn_1_type": "flash_attn2",
  "cross_attn_2_type": "flash_attn2",
  "seed": 0,
  "sample_guide_scale": 6,
  "sample_shift": 8,
  "enable_cfg": false,
  "cpu_offload": false,
  "sf_config": {
    "local_attn_size": 6,
    "shift": 5.0,
    "num_frame_per_block": 3,
    "num_transformer_blocks": 30,
    "frame_seq_length": 880,
    "num_output_frames": 150,
    "num_inference_steps": 1000,
    "denoising_step_list": [1000.0000, 908.8427, 713.9794]
  },
  "sub_model_folder": "templerun_distilled_model",
  "sub_model_name": "templerun_7dim_onlykey.safetensors",
  "mode": "templerun",
  "streaming": false,
  "action_config": {
    "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
    "enable_keyboard": true,
    "enable_mouse": false,
    "heads_num": 16,
    "hidden_size": 128,
    "img_hidden_size": 1536,
    "keyboard_dim_in": 7,
    "keyboard_hidden_dim": 1024,
    "patch_size": [1, 2, 2],
    "qk_norm": true,
    "qkv_bias": false,
    "rope_dim_list": [8, 28, 28],
    "rope_theta": 256,
    "vae_time_compression_ratio": 4,
    "windows_size": 3
  },
  "dim": 1536,
  "eps": 1e-06,
  "ffn_dim": 8960,
  "freq_dim": 256,
  "in_dim": 36,
  "inject_sample_info": false,
  "model_type": "i2v",
  "num_heads": 12,
  "num_layers": 30,
  "out_dim": 16
}
```
configs/matrix_game2/matrix_game2_templerun_streaming.json (new file, mode 100644)
```json
{
  "infer_steps": 50,
  "target_video_length": 360,
  "num_output_frames": 360,
  "text_len": 512,
  "target_height": 480,
  "target_width": 832,
  "self_attn_1_type": "flash_attn2",
  "cross_attn_1_type": "flash_attn2",
  "cross_attn_2_type": "flash_attn2",
  "seed": 0,
  "sample_guide_scale": 6,
  "sample_shift": 8,
  "enable_cfg": false,
  "cpu_offload": false,
  "sf_config": {
    "local_attn_size": 6,
    "shift": 5.0,
    "num_frame_per_block": 3,
    "num_transformer_blocks": 30,
    "frame_seq_length": 880,
    "num_output_frames": 360,
    "num_inference_steps": 1000,
    "denoising_step_list": [1000.0000, 908.8427, 713.9794]
  },
  "sub_model_folder": "templerun_distilled_model",
  "sub_model_name": "templerun_7dim_onlykey.safetensors",
  "mode": "templerun",
  "streaming": true,
  "action_config": {
    "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
    "enable_keyboard": true,
    "enable_mouse": true,
    "heads_num": 16,
    "hidden_size": 128,
    "img_hidden_size": 1536,
    "keyboard_dim_in": 4,
    "keyboard_hidden_dim": 1024,
    "mouse_dim_in": 2,
    "mouse_hidden_dim": 1024,
    "mouse_qk_dim_list": [8, 28, 28],
    "patch_size": [1, 2, 2],
    "qk_norm": true,
    "qkv_bias": false,
    "rope_dim_list": [8, 28, 28],
    "rope_theta": 256,
    "vae_time_compression_ratio": 4,
    "windows_size": 3
  },
  "dim": 1536,
  "eps": 1e-06,
  "ffn_dim": 8960,
  "freq_dim": 256,
  "in_dim": 36,
  "inject_sample_info": false,
  "model_type": "i2v",
  "num_heads": 12,
  "num_layers": 30,
  "out_dim": 16
}
```
configs/matrix_game2/matrix_game2_universal.json (new file, mode 100644)
```json
{
  "infer_steps": 50,
  "target_video_length": 150,
  "num_output_frames": 150,
  "text_len": 512,
  "target_height": 480,
  "target_width": 832,
  "self_attn_1_type": "flash_attn2",
  "cross_attn_1_type": "flash_attn2",
  "cross_attn_2_type": "flash_attn2",
  "seed": 0,
  "sample_guide_scale": 6,
  "sample_shift": 8,
  "enable_cfg": false,
  "cpu_offload": false,
  "sf_config": {
    "local_attn_size": 6,
    "shift": 5.0,
    "num_frame_per_block": 3,
    "num_transformer_blocks": 30,
    "frame_seq_length": 880,
    "num_output_frames": 150,
    "num_inference_steps": 1000,
    "denoising_step_list": [1000.0000, 908.8427, 713.9794]
  },
  "sub_model_folder": "base_distilled_model",
  "sub_model_name": "base_distill.safetensors",
  "mode": "universal",
  "streaming": false,
  "action_config": {
    "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
    "enable_keyboard": true,
    "enable_mouse": true,
    "heads_num": 16,
    "hidden_size": 128,
    "img_hidden_size": 1536,
    "keyboard_dim_in": 4,
    "keyboard_hidden_dim": 1024,
    "mouse_dim_in": 2,
    "mouse_hidden_dim": 1024,
    "mouse_qk_dim_list": [8, 28, 28],
    "patch_size": [1, 2, 2],
    "qk_norm": true,
    "qkv_bias": false,
    "rope_dim_list": [8, 28, 28],
    "rope_theta": 256,
    "vae_time_compression_ratio": 4,
    "windows_size": 3
  },
  "dim": 1536,
  "eps": 1e-06,
  "ffn_dim": 8960,
  "freq_dim": 256,
  "in_dim": 36,
  "inject_sample_info": false,
  "model_type": "i2v",
  "num_heads": 12,
  "num_layers": 30,
  "out_dim": 16
}
```
configs/matrix_game2/matrix_game2_universal_streaming.json (new file, mode 100644)
```json
{
  "infer_steps": 50,
  "target_video_length": 360,
  "num_output_frames": 360,
  "text_len": 512,
  "target_height": 480,
  "target_width": 832,
  "self_attn_1_type": "flash_attn2",
  "cross_attn_1_type": "flash_attn2",
  "cross_attn_2_type": "flash_attn2",
  "seed": 0,
  "sample_guide_scale": 6,
  "sample_shift": 8,
  "enable_cfg": false,
  "cpu_offload": false,
  "sf_config": {
    "local_attn_size": 6,
    "shift": 5.0,
    "num_frame_per_block": 3,
    "num_transformer_blocks": 30,
    "frame_seq_length": 880,
    "num_output_frames": 360,
    "num_inference_steps": 1000,
    "denoising_step_list": [1000.0000, 908.8427, 713.9794]
  },
  "sub_model_folder": "base_distilled_model",
  "sub_model_name": "base_distill.safetensors",
  "mode": "universal",
  "streaming": true,
  "action_config": {
    "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
    "enable_keyboard": true,
    "enable_mouse": true,
    "heads_num": 16,
    "hidden_size": 128,
    "img_hidden_size": 1536,
    "keyboard_dim_in": 4,
    "keyboard_hidden_dim": 1024,
    "mouse_dim_in": 2,
    "mouse_hidden_dim": 1024,
    "mouse_qk_dim_list": [8, 28, 28],
    "patch_size": [1, 2, 2],
    "qk_norm": true,
    "qkv_bias": false,
    "rope_dim_list": [8, 28, 28],
    "rope_theta": 256,
    "vae_time_compression_ratio": 4,
    "windows_size": 3
  },
  "dim": 1536,
  "eps": 1e-06,
  "ffn_dim": 8960,
  "freq_dim": 256,
  "in_dim": 36,
  "inject_sample_info": false,
  "model_type": "i2v",
  "num_heads": 12,
  "num_layers": 30,
  "out_dim": 16
}
```
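The six configs share all model hyper-parameters and differ only in `target_video_length`/`num_output_frames` (150 offline vs. 360 streaming), the `mode` and `streaming` flags, and which distilled sub-model they point at. Notably, `matrix_game2_templerun_streaming.json` reuses the 4-dim keyboard/mouse `action_config` of the drive configs rather than the 7-dim keyboard-only one used by the offline templerun config. A minimal comparison sketch (the relative paths assume the repository root):

```python
import json

# A quick comparison sketch: print the fields that actually vary per config.
names = [
    "matrix_game2_universal", "matrix_game2_universal_streaming",
    "matrix_game2_gta_drive", "matrix_game2_gta_drive_streaming",
    "matrix_game2_templerun", "matrix_game2_templerun_streaming",
]
for name in names:
    with open(f"configs/matrix_game2/{name}.json") as f:
        cfg = json.load(f)
    print(
        f"{cfg['mode']:<10} streaming={cfg['streaming']!s:<5} "
        f"frames={cfg['num_output_frames']:<4} sub_model={cfg['sub_model_name']}"
    )
```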
lightx2v/common/ops/mm/mm_weight.py
```diff
@@ -51,7 +51,7 @@ except ImportError:
 try:
     from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
-except ModuleNotFoundError:
+except ImportError:
     quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None

 try:
@@ -61,7 +61,7 @@ except ImportError:
 try:
     import marlin_cuda_quant
-except ModuleNotFoundError:
+except ImportError:
     marlin_cuda_quant = None
```
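Context for the two hunks above (and the matching ones in `q_linear.py` below): `ModuleNotFoundError` is a subclass of `ImportError`, so widening the handler loses nothing but additionally catches imports that fail for reasons other than a missing package, e.g. a name that does not exist in an installed module. A minimal illustration:

```python
# ModuleNotFoundError only fires when the module itself is absent;
# ImportError also covers a present module that lacks the requested name.
try:
    from math import no_such_symbol  # math exists, the symbol does not
except ModuleNotFoundError:
    print("not reached: the math module is importable")
except ImportError:
    print("reached: ImportError catches the failed symbol import")
```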
lightx2v/infer.py
```diff
@@ -9,6 +9,7 @@ from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner
 from lightx2v.models.runners.wan.wan_animate_runner import WanAnimateRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
+from lightx2v.models.runners.wan.wan_matrix_game2_runner import WanSFMtxg2Runner  # noqa: F401
 from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner  # noqa: F401
@@ -39,6 +40,7 @@ def main():
         "wan2.1_distill",
         "wan2.1_vace",
         "wan2.1_sf",
+        "wan2.1_sf_mtxg2",
         "seko_talk",
         "wan2.2_moe",
         "wan2.2",
```
lightx2v/models/input_encoders/hf/q_linear.py
```diff
@@ -3,7 +3,7 @@ import torch.nn as nn
 try:
     from vllm import _custom_ops as ops
-except ModuleNotFoundError:
+except ImportError:
     ops = None

 try:
@@ -13,7 +13,7 @@ except ImportError:
 try:
     from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
-except ModuleNotFoundError:
+except ImportError:
     quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None

 try:
```
lightx2v/models/input_encoders/hf/wan/matrix_game2/__init__.py (new file, mode 100644; empty)
lightx2v/models/input_encoders/hf/wan/matrix_game2/clip.py (new file, mode 100644)
```python
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from diffusers.models import ModelMixin

from lightx2v.models.input_encoders.hf.wan.matrix_game2.tokenizers import HuggingfaceTokenizer
from lightx2v.models.input_encoders.hf.wan.xlm_roberta.model import VisionTransformer


class SelfAttention(nn.Module):
    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        x: [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)

        # compute attention
        p = self.dropout.p if self.training else 0.0
        x = F.scaled_dot_product_attention(q, k, v, mask, p)
        x = x.permute(0, 2, 1, 3).reshape(b, s, c)

        # output
        x = self.o(x)
        x = self.dropout(x)
        return x


class AttentionBlock(nn.Module):
    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.eps = eps

        # layers
        self.attn = SelfAttention(dim, num_heads, dropout, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout))
        self.norm2 = nn.LayerNorm(dim, eps=eps)

    def forward(self, x, mask):
        if self.post_norm:
            x = self.norm1(x + self.attn(x, mask))
            x = self.norm2(x + self.ffn(x))
        else:
            x = x + self.attn(self.norm1(x), mask)
            x = x + self.ffn(self.norm2(x))
        return x


class XLMRoberta(nn.Module):
    """
    XLMRobertaModel with no pooler and no LM head.
    """

    def __init__(self, vocab_size=250002, max_seq_len=514, type_size=1, pad_id=1, dim=1024, num_heads=16, num_layers=24, post_norm=True, dropout=0.1, eps=1e-5):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.post_norm = post_norm
        self.eps = eps

        # embeddings
        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
        self.type_embedding = nn.Embedding(type_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)

        # blocks
        self.blocks = nn.ModuleList([AttentionBlock(dim, num_heads, post_norm, dropout, eps) for _ in range(num_layers)])

        # norm layer
        self.norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, ids):
        """
        ids: [B, L] of torch.LongTensor.
        """
        b, s = ids.shape
        mask = ids.ne(self.pad_id).long()

        # embeddings
        x = self.token_embedding(ids) + self.type_embedding(torch.zeros_like(ids)) + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
        if self.post_norm:
            x = self.norm(x)
        x = self.dropout(x)

        # blocks
        mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
        for block in self.blocks:
            x = block(x, mask)

        # output
        if not self.post_norm:
            x = self.norm(x)
        return x


class XLMRobertaWithHead(XLMRoberta):
    def __init__(self, **kwargs):
        self.out_dim = kwargs.pop("out_dim")
        super().__init__(**kwargs)

        # head
        mid_dim = (self.dim + self.out_dim) // 2
        self.head = nn.Sequential(nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False))

    def forward(self, ids):
        # xlm-roberta
        x = super().forward(ids)

        # average pooling
        mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
        x = (x * mask).sum(dim=1) / mask.sum(dim=1)

        # head
        x = self.head(x)
        return x


class XLMRobertaCLIP(nn.Module):
    def __init__(
        self,
        dtype=torch.float16,
        embed_dim=1024,
        image_size=224,
        patch_size=14,
        vision_dim=1280,
        vision_mlp_ratio=4,
        vision_heads=16,
        vision_layers=32,
        vision_pool="token",
        vision_pre_norm=True,
        vision_post_norm=False,
        activation="gelu",
        vocab_size=250002,
        max_text_len=514,
        type_size=1,
        pad_id=1,
        attn_dropout=0.0,
        proj_dropout=0.0,
        embedding_dropout=0.0,
        norm_eps=1e-5,
        quantized=False,
        quant_scheme=None,
        text_dim=1024,
        text_heads=16,
        text_layers=24,
        text_post_norm=True,
        text_dropout=0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_mlp_ratio = vision_mlp_ratio
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vision_pre_norm = vision_pre_norm
        self.vision_post_norm = vision_post_norm
        self.activation = activation
        self.vocab_size = vocab_size
        self.max_text_len = max_text_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.norm_eps = norm_eps

        # models
        self.visual = VisionTransformer(
            dtype=dtype,
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            mlp_ratio=vision_mlp_ratio,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            pool_type=vision_pool,
            pre_norm=vision_pre_norm,
            post_norm=vision_post_norm,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps,
            quantized=quantized,
            quant_scheme=quant_scheme,
        )
        self.textual = XLMRobertaWithHead(
            vocab_size=vocab_size,
            max_seq_len=max_text_len,
            type_size=type_size,
            pad_id=pad_id,
            dim=text_dim,
            out_dim=embed_dim,
            num_heads=text_heads,
            num_layers=text_layers,
            post_norm=text_post_norm,
            dropout=text_dropout,
        )
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))


def _clip(
    pretrained=False,
    pretrained_name=None,
    model_cls=XLMRobertaCLIP,
    return_transforms=False,
    return_tokenizer=False,
    tokenizer_padding="eos",
    dtype=torch.float32,
    device="cpu",
    **kwargs,
):
    # init a model on device
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)
    output = (model,)

    # init transforms
    if return_transforms:
        # mean and std
        if "siglip" in pretrained_name.lower():
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]

        # transforms
        transforms = T.Compose([T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=mean, std=std)])
        output += (transforms,)
    return output[0] if len(output) == 1 else output


def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs):
    cfg = dict(
        embed_dim=1024,
        image_size=224,
        patch_size=14,
        vision_dim=1280,
        vision_mlp_ratio=4,
        vision_heads=16,
        vision_layers=32,
        vision_pool="token",
        activation="gelu",
        vocab_size=250002,
        max_text_len=514,
        type_size=1,
        pad_id=1,
        text_dim=1024,
        text_heads=16,
        text_layers=24,
        text_post_norm=True,
        text_dropout=0.1,
        attn_dropout=0.0,
        proj_dropout=0.0,
        embedding_dropout=0.0,
    )
    cfg.update(**kwargs)
    return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)


class CLIPModel(ModelMixin):
    def __init__(self, checkpoint_path, tokenizer_path):
        super().__init__()
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model
        self.model, self.transforms = clip_xlm_roberta_vit_h_14(
            pretrained=False,
            return_transforms=True,
            return_tokenizer=False,
        )
        self.model = self.model.eval().requires_grad_(False)
        logging.info(f"loading {checkpoint_path}")
        self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace")

    def encode_video(self, video):
        # preprocess
        b, c, t, h, w = video.shape
        video = video.transpose(1, 2)
        video = video.reshape(b * t, c, h, w)
        size = (self.model.image_size,) * 2
        video = F.interpolate(video, size=size, mode="bicubic", align_corners=False)
        video = self.transforms.transforms[-1](video.mul_(0.5).add_(0.5))

        # forward
        with torch.amp.autocast(dtype=self.dtype, device_type=self.device.type):
            out = self.model.visual(video, use_31_block=True)
            return out

    def forward(self, videos):
        # preprocess
        size = (self.model.image_size,) * 2
        videos = torch.cat([F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos])
        videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

        # forward
        with torch.amp.autocast("cuda", dtype=self.dtype):
            out = self.model.visual(videos, use_31_block=True)
            return out
```
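A usage sketch for the encoder above; the checkpoint and tokenizer paths are placeholders, and the expected `[-1, 1]` input range is implied by the `mul_(0.5).add_(0.5)` renormalization inside `encode_video`:

```python
import torch

# Placeholder paths: the open-clip-xlm-roberta-large-vit-huge-14 checkpoint
# and an XLM-Roberta tokenizer directory.
clip = CLIPModel(
    checkpoint_path="/path/to/clip_checkpoint.pth",
    tokenizer_path="/path/to/xlm-roberta-large",
).to("cuda", dtype=torch.float16)

# One 9-frame clip in [-1, 1]; encode_video resizes to 224x224 internally and
# returns intermediate ViT features (use_31_block=True).
video = torch.rand(1, 3, 9, 480, 832, device="cuda", dtype=torch.float16) * 2 - 1
feats = clip.encode_video(video)
print(feats.shape)
```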
lightx2v/models/input_encoders/hf/wan/matrix_game2/conditions.py (new file, mode 100644)
```python
import random

import torch


def combine_data(data, num_frames=57, keyboard_dim=6, mouse=True):
    assert num_frames % 4 == 1
    keyboard_condition = torch.zeros((num_frames, keyboard_dim))
    if mouse:
        mouse_condition = torch.zeros((num_frames, 2))
    current_frame = 0
    selections = [12]
    while current_frame < num_frames:
        rd_frame = selections[random.randint(0, len(selections) - 1)]
        rd = random.randint(0, len(data) - 1)
        k = data[rd]["keyboard_condition"]
        if mouse:
            m = data[rd]["mouse_condition"]
        if current_frame == 0:
            keyboard_condition[:1] = k[:1]
            if mouse:
                mouse_condition[:1] = m[:1]
            current_frame = 1
        else:
            rd_frame = min(rd_frame, num_frames - current_frame)
            repeat_time = rd_frame // 4
            keyboard_condition[current_frame : current_frame + rd_frame] = k.repeat(repeat_time, 1)
            if mouse:
                mouse_condition[current_frame : current_frame + rd_frame] = m.repeat(repeat_time, 1)
            current_frame += rd_frame
    if mouse:
        return {"keyboard_condition": keyboard_condition, "mouse_condition": mouse_condition}
    return {"keyboard_condition": keyboard_condition}


def Bench_actions_universal(num_frames, num_samples_per_action=4):
    actions_single_action = [
        "forward",
        # "back",
        "left",
        "right",
    ]
    actions_double_action = [
        "forward_left",
        "forward_right",
        # "back_left",
        # "back_right",
    ]
    actions_single_camera = [
        "camera_l",
        "camera_r",
        # "camera_ur",
        # "camera_ul",
        # "camera_dl",
        # "camera_dr"
        # "camera_up",
        # "camera_down",
    ]
    actions_to_test = actions_double_action * 5 + actions_single_camera * 5 + actions_single_action * 5
    for action in actions_single_action + actions_double_action:
        for camera in actions_single_camera:
            double_action = f"{action}_{camera}"
            actions_to_test.append(double_action)
    # print("length of actions: ", len(actions_to_test))
    base_action = actions_single_action + actions_single_camera
    KEYBOARD_IDX = {"forward": 0, "back": 1, "left": 2, "right": 3}
    CAM_VALUE = 0.1
    CAMERA_VALUE_MAP = {
        "camera_up": [CAM_VALUE, 0],
        "camera_down": [-CAM_VALUE, 0],
        "camera_l": [0, -CAM_VALUE],
        "camera_r": [0, CAM_VALUE],
        "camera_ur": [CAM_VALUE, CAM_VALUE],
        "camera_ul": [CAM_VALUE, -CAM_VALUE],
        "camera_dr": [-CAM_VALUE, CAM_VALUE],
        "camera_dl": [-CAM_VALUE, -CAM_VALUE],
    }
    data = []
    for action_name in actions_to_test:
        keyboard_condition = [[0, 0, 0, 0] for _ in range(num_samples_per_action)]
        mouse_condition = [[0, 0] for _ in range(num_samples_per_action)]
        for sub_act in base_action:
            if sub_act not in action_name:
                # only handle the sub-actions contained in action_name
                continue
            # print(f"action name: {action_name} sub_act: {sub_act}")
            if sub_act in CAMERA_VALUE_MAP:
                mouse_condition = [CAMERA_VALUE_MAP[sub_act] for _ in range(num_samples_per_action)]
            elif sub_act in KEYBOARD_IDX:
                col = KEYBOARD_IDX[sub_act]
                for row in keyboard_condition:
                    row[col] = 1
        data.append({"keyboard_condition": torch.tensor(keyboard_condition), "mouse_condition": torch.tensor(mouse_condition)})
    return combine_data(data, num_frames, keyboard_dim=4, mouse=True)


def Bench_actions_gta_drive(num_frames, num_samples_per_action=4):
    actions_single_action = [
        "forward",
        "back",
    ]
    actions_single_camera = [
        "camera_l",
        "camera_r",
    ]
    actions_to_test = actions_single_camera * 2 + actions_single_action * 2
    for action in actions_single_action:
        for camera in actions_single_camera:
            double_action = f"{action}_{camera}"
            actions_to_test.append(double_action)
    # print("length of actions: ", len(actions_to_test))
    base_action = actions_single_action + actions_single_camera
    KEYBOARD_IDX = {"forward": 0, "back": 1}
    CAM_VALUE = 0.1
    CAMERA_VALUE_MAP = {
        "camera_l": [0, -CAM_VALUE],
        "camera_r": [0, CAM_VALUE],
    }
    data = []
    for action_name in actions_to_test:
        keyboard_condition = [[0, 0] for _ in range(num_samples_per_action)]
        mouse_condition = [[0, 0] for _ in range(num_samples_per_action)]
        for sub_act in base_action:
            if sub_act not in action_name:
                # only handle the sub-actions contained in action_name
                continue
            # print(f"action name: {action_name} sub_act: {sub_act}")
            if sub_act in CAMERA_VALUE_MAP:
                mouse_condition = [CAMERA_VALUE_MAP[sub_act] for _ in range(num_samples_per_action)]
            elif sub_act in KEYBOARD_IDX:
                col = KEYBOARD_IDX[sub_act]
                for row in keyboard_condition:
                    row[col] = 1
        data.append({"keyboard_condition": torch.tensor(keyboard_condition), "mouse_condition": torch.tensor(mouse_condition)})
    return combine_data(data, num_frames, keyboard_dim=2, mouse=True)


def Bench_actions_templerun(num_frames, num_samples_per_action=4):
    actions_single_action = ["jump", "slide", "leftside", "rightside", "turnleft", "turnright", "nomove"]
    actions_to_test = actions_single_action
    base_action = actions_single_action
    KEYBOARD_IDX = {"nomove": 0, "jump": 1, "slide": 2, "turnleft": 3, "turnright": 4, "leftside": 5, "rightside": 6}
    data = []
    for action_name in actions_to_test:
        keyboard_condition = [[0, 0, 0, 0, 0, 0, 0] for _ in range(num_samples_per_action)]
        for sub_act in base_action:
            if sub_act not in action_name:
                # only handle the sub-actions contained in action_name
                continue
            # print(f"action name: {action_name} sub_act: {sub_act}")
            elif sub_act in KEYBOARD_IDX:
                col = KEYBOARD_IDX[sub_act]
                for row in keyboard_condition:
                    row[col] = 1
        data.append({"keyboard_condition": torch.tensor(keyboard_condition)})
    return combine_data(data, num_frames, keyboard_dim=7, mouse=False)


class MatrixGame2_Bench:
    def __init__(self):
        self.device = torch.device("cuda")
        self.weight_dtype = torch.bfloat16

    def get_conditions(self, mode, num_frames):
        conditional_dict = {}
        if mode == "universal":
            cond_data = Bench_actions_universal(num_frames)
            mouse_condition = cond_data["mouse_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
            conditional_dict["mouse_cond"] = mouse_condition
        elif mode == "gta_drive":
            cond_data = Bench_actions_gta_drive(num_frames)
            mouse_condition = cond_data["mouse_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
            conditional_dict["mouse_cond"] = mouse_condition
        else:
            cond_data = Bench_actions_templerun(num_frames)
        keyboard_condition = cond_data["keyboard_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
        conditional_dict["keyboard_cond"] = keyboard_condition
        return conditional_dict
```
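A quick sketch of what these benchmark helpers return; `combine_data` asserts `num_frames % 4 == 1`, matching the `1 + 4k` pixel-frame layout of the temporally 4x-compressed VAE:

```python
from lightx2v.models.input_encoders.hf.wan.matrix_game2.conditions import (
    Bench_actions_templerun,
    Bench_actions_universal,
)

# 57 = 1 + 4 * 14 frames, so the num_frames % 4 == 1 assert passes.
cond = Bench_actions_universal(57)
print(cond["keyboard_condition"].shape)  # torch.Size([57, 4]) multi-hot movement keys
print(cond["mouse_condition"].shape)     # torch.Size([57, 2]) camera deltas (vertical, horizontal)

# Temple Run mode is keyboard-only, with 7 discrete action slots.
cond_tr = Bench_actions_templerun(57)
print(cond_tr["keyboard_condition"].shape)  # torch.Size([57, 7])
```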
lightx2v/models/input_encoders/hf/wan/matrix_game2/tokenizers.py (new file, mode 100644)
```python
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string

import ftfy
import regex as re
from transformers import AutoTokenizer

__all__ = ["HuggingfaceTokenizer"]


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(part.translate(str.maketrans("", "", string.punctuation)) for part in text.split(keep_punctuation_exact_string))
    else:
        text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()
    text = re.sub(r"\s+", " ", text)
    return text.strip()


class HuggingfaceTokenizer:
    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, "whitespace", "lower", "canonicalize")
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        return_mask = kwargs.pop("return_mask", False)

        # arguments
        _kwargs = {"return_tensors": "pt"}
        if self.seq_len is not None:
            _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len})
        _kwargs.update(**kwargs)

        # tokenization
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        if self.clean == "whitespace":
            text = whitespace_clean(basic_clean(text))
        elif self.clean == "lower":
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == "canonicalize":
            text = canonicalize(basic_clean(text))
        return text
```
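Usage sketch (the tokenizer path is a placeholder; `clip.py` above constructs this wrapper with `seq_len=self.model.max_text_len - 2` and `clean="whitespace"`):

```python
from lightx2v.models.input_encoders.hf.wan.matrix_game2.tokenizers import HuggingfaceTokenizer

# "xlm-roberta-large" stands in for the local tokenizer directory.
tokenizer = HuggingfaceTokenizer(name="xlm-roberta-large", seq_len=512, clean="whitespace")
ids, mask = tokenizer("a red car driving through rain", return_mask=True)
print(ids.shape, mask.shape)  # both torch.Size([1, 512]) after max_length padding
```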
lightx2v/models/networks/wan/infer/matrix_game2/posemb_layers.py (new file, mode 100644)
```python
from typing import List, Tuple, Union

import torch


def _to_tuple(x, dim=2):
    if isinstance(x, int):
        return (x,) * dim
    elif len(x) == dim:
        return x
    else:
        raise ValueError(f"Expected length {dim} or int, but got {x}")


def get_meshgrid_nd(start, *args, dim=2):
    """
    Get n-D meshgrid with start, stop and num.

    Args:
        start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
            step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
            should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
            n-tuples.
        *args: See above.
        dim (int): Dimension of the meshgrid. Defaults to 2.

    Returns:
        grid (np.ndarray): [dim, ...]
    """
    if len(args) == 0:
        # start is grid_size
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        # start is start, args[0] is stop, step is 1
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [stop[i] - start[i] for i in range(dim)]
    elif len(args) == 2:
        # start is start, args[0] is stop, args[1] is num
        start = _to_tuple(start, dim=dim)  # Left-Top   eg: 12,0
        stop = _to_tuple(args[0], dim=dim)  # Right-Bottom   eg: 20,32
        num = _to_tuple(args[1], dim=dim)  # Target Size   eg: 32,124
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
    axis_grid = []
    for i in range(dim):
        a, b, n = start[i], stop[i], num[i]
        g = torch.linspace(a, b, n + 1, dtype=torch.float32, device=torch.cuda.current_device())[:n]
        axis_grid.append(g)
    grid = torch.meshgrid(*axis_grid, indexing="ij")  # dim x [W, H, D]
    grid = torch.stack(grid, dim=0)  # [dim, W, H, D]

    return grid


#################################################################################
#                   Rotary Positional Embedding Functions                       #
#################################################################################
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80


def reshape_for_broadcast(
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    x: torch.Tensor,
    head_first=False,
):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Notes:
        When using FlashMHAModified, head_first should be False.
        When using Attention, head_first should be True.

    Args:
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim

    if isinstance(freqs_cis, tuple):
        # freqs_cis: (cos, sin) in real space
        if head_first:
            assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            # assert freqs_cis[0].shape == (
            #     x.shape[1],
            #     x.shape[-1],
            # ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            # shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
            shape = [1, freqs_cis[0].shape[0], 1, freqs_cis[0].shape[1]]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    else:
        # freqs_cis: values in complex space
        if head_first:
            assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)


def rotate_half(x):
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    head_first: bool = False,
    start_offset: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
        xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
        freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    # print(freqs_cis[0].shape, xq.shape, xk.shape)
    xk_out = None
    assert isinstance(freqs_cis, tuple)
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        # real * cos - imag * sin
        # imag * cos + real * sin
        xq_out = (xq.float() * cos[:, start_offset : start_offset + xq.shape[1], :, :] + rotate_half(xq.float()) * sin[:, start_offset : start_offset + xq.shape[1], :, :]).type_as(xq)
        xk_out = (xk.float() * cos[:, start_offset : start_offset + xk.shape[1], :, :] + rotate_half(xk.float()) * sin[:, start_offset : start_offset + xk.shape[1], :, :]).type_as(xk)
    else:
        # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
        # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
        # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out


def get_nd_rotary_pos_embed(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    use_real=False,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
):
    """
    This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.

    Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
            sum(rope_dim_list) should equal to head_dim of attention layer.
        start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
            args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
        *args: See above.
        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
            Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
            part and an imaginary part separately.
        theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.

    Returns:
        pos_embed (torch.Tensor): [HW, D/2]
    """
    grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))  # [3, W, H, D] / [2, W, H]

    if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)"

    if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)"

    # use 1/ndim of dimensions to encode grid_axis
    embs = []
    for i in range(len(rope_dim_list)):
        emb = get_1d_rotary_pos_embed(
            rope_dim_list[i],
            grid[i].reshape(-1),
            theta,
            use_real=use_real,
            theta_rescale_factor=theta_rescale_factor[i],
            interpolation_factor=interpolation_factor[i],
        )  # 2 x [WHD, rope_dim_list[i]]
        embs.append(emb)

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D/2)
        sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D/2)
        return cos, sin
    else:
        emb = torch.cat(embs, dim=1)  # (WHD, D/2)
        return emb


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Precompute the frequency tensor for complex exponential (cis) with given dimensions.
    (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)

    This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool, optional): If True, return real part and imaginary part separately.
            Otherwise, return complex numbers.
        theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.

    Returns:
        freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
        freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
    """
    if isinstance(pos, int):
        pos = torch.arange(pos, device=torch.cuda.current_device()).float()

    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 2))

    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=torch.cuda.current_device())[: (dim // 2)].float() / dim))  # [D/2]
    # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
    freqs = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin
    else:
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64     # [S, D/2]
        return freqs_cis
```
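A sketch of the intended call pattern. CUDA is required, since the grid and frequency helpers hard-code `torch.cuda.current_device()`; the sizes below mirror `rope_dim_list=[8, 28, 28]` from the configs (summing to a 64-dim head), while the grid and head count are illustrative assumptions:

```python
import torch

from lightx2v.models.networks.wan.infer.matrix_game2.posemb_layers import (
    apply_rotary_emb,
    get_nd_rotary_pos_embed,
)

# 3 latent frames x 22 x 40 spatial tokens; rope_dim_list sums to the head dim (64).
t, h, w, heads = 3, 22, 40, 16
cos, sin = get_nd_rotary_pos_embed([8, 28, 28], (t, h, w), use_real=True)  # each [t*h*w, 64]

q = torch.randn(1, t * h * w, heads, 64, device="cuda")
k = torch.randn(1, t * h * w, heads, 64, device="cuda")
q_rot, k_rot = apply_rotary_emb(q, k, (cos, sin), head_first=False)
print(q_rot.shape)  # torch.Size([1, 2640, 16, 64])
```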
lightx2v/models/networks/wan/infer/matrix_game2/pre_infer.py (new file, mode 100644)
```python
import torch

from lightx2v.models.networks.wan.infer.module_io import GridOutput, WanPreInferModuleOutput
from lightx2v.models.networks.wan.infer.self_forcing.pre_infer import WanSFPreInfer, sinusoidal_embedding_1d
from lightx2v.utils.envs import *


def cond_current(conditional_dict, current_start_frame, num_frame_per_block, replace=None, mode="universal"):
    new_cond = {}
    new_cond["cond_concat"] = conditional_dict["image_encoder_output"]["cond_concat"][:, :, current_start_frame : current_start_frame + num_frame_per_block]
    new_cond["visual_context"] = conditional_dict["image_encoder_output"]["visual_context"]
    if replace:
        if current_start_frame == 0:
            last_frame_num = 1 + 4 * (num_frame_per_block - 1)
        else:
            last_frame_num = 4 * num_frame_per_block
        final_frame = 1 + 4 * (current_start_frame + num_frame_per_block - 1)
        if mode != "templerun":
            conditional_dict["text_encoder_output"]["mouse_cond"][:, -last_frame_num + final_frame : final_frame] = replace["mouse"][None, None, :].repeat(1, last_frame_num, 1)
        conditional_dict["text_encoder_output"]["keyboard_cond"][:, -last_frame_num + final_frame : final_frame] = replace["keyboard"][None, None, :].repeat(1, last_frame_num, 1)
    if mode != "templerun":
        new_cond["mouse_cond"] = conditional_dict["text_encoder_output"]["mouse_cond"][:, : 1 + 4 * (current_start_frame + num_frame_per_block - 1)]
    new_cond["keyboard_cond"] = conditional_dict["text_encoder_output"]["keyboard_cond"][:, : 1 + 4 * (current_start_frame + num_frame_per_block - 1)]
    if replace:
        return new_cond, conditional_dict
    else:
        return new_cond


# @amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    assert dim % 2 == 0
    freqs = torch.outer(torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs


class WanMtxg2PreInfer(WanSFPreInfer):
    def __init__(self, config):
        super().__init__(config)
        d = config["dim"] // config["num_heads"]
        self.freqs = torch.cat(
            [
                rope_params(1024, d - 4 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
            ],
            dim=1,
        ).to(torch.device("cuda"))
        self.dim = config["dim"]

    def img_emb(self, weights, x):
        x = weights.img_emb_0.apply(x)
        x = weights.img_emb_1.apply(x.squeeze(0))
        x = torch.nn.functional.gelu(x, approximate="none")
        x = weights.img_emb_3.apply(x)
        x = weights.img_emb_4.apply(x)
        x = x.unsqueeze(0)
        return x

    @torch.no_grad()
    def infer(self, weights, inputs, kv_start=0, kv_end=0):
        x = self.scheduler.latents_input
        t = self.scheduler.timestep_input
        current_start_frame = self.scheduler.seg_index * self.scheduler.num_frame_per_block
        if self.config["streaming"]:
            current_actions = inputs["current_actions"]
            current_conditional_dict, _ = cond_current(inputs, current_start_frame, self.scheduler.num_frame_per_block, replace=current_actions, mode=self.config["mode"])
        else:
            current_conditional_dict = cond_current(inputs, current_start_frame, self.scheduler.num_frame_per_block, mode=self.config["mode"])
        cond_concat = current_conditional_dict["cond_concat"]
        visual_context = current_conditional_dict["visual_context"]

        x = torch.cat([x.unsqueeze(0), cond_concat], dim=1)

        # embeddings
        x = weights.patch_embedding.apply(x)
        grid_sizes_t, grid_sizes_h, grid_sizes_w = torch.tensor(x.shape[2:], dtype=torch.long)
        grid_sizes = GridOutput(
            tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device),
            tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w),
        )
        x = x.flatten(2).transpose(1, 2)  # B FHW C'
        seq_lens = torch.tensor([u.size(0) for u in x], dtype=torch.long, device=torch.device("cuda"))
        assert seq_lens[0] <= 15 * 1 * 880

        embed_tmp = sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x)  # torch.Size([3, 256])
        embed = self.time_embedding(weights, embed_tmp)  # torch.Size([3, 1536])
        embed0 = self.time_projection(weights, embed).unflatten(dim=0, sizes=t.shape)

        # context
        context_lens = None
        context = self.img_emb(weights, visual_context)

        return WanPreInferModuleOutput(
            embed=embed,
            grid_sizes=grid_sizes,
            x=x.squeeze(0),
            embed0=embed0.squeeze(0),
            seq_lens=seq_lens,
            freqs=self.freqs,
            context=context[0],
            conditional_dict=current_conditional_dict,
        )
```
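The recurring `1 + 4 * (frame - 1)` expressions in `cond_current` convert latent-frame indices into condition/pixel-frame counts under the VAE's 4x temporal compression (the first latent frame decodes to one pixel frame, each subsequent latent frame to four). A worked example with `num_frame_per_block=3`, as in all six configs:

```python
# Worked example of cond_current's frame bookkeeping (num_frame_per_block=3).
num_frame_per_block = 3
for current_start_frame in (0, 3, 6):
    end = 1 + 4 * (current_start_frame + num_frame_per_block - 1)
    print(f"block starting at latent frame {current_start_frame}: conditions sliced [:, :{end}]")
# block starting at latent frame 0: conditions sliced [:, :9]
# block starting at latent frame 3: conditions sliced [:, :21]
# block starting at latent frame 6: conditions sliced [:, :33]
```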
lightx2v/models/networks/wan/infer/matrix_game2/transformer_infer.py (new file, mode 100644; +668 −0 — diff collapsed on the page, contents not reproduced here)
lightx2v/models/networks/wan/infer/module_io.py
```diff
@@ -20,3 +20,4 @@ class WanPreInferModuleOutput:
     freqs: torch.Tensor
     context: torch.Tensor
     adapter_args: Dict[str, Any] = field(default_factory=dict)
+    conditional_dict: Dict[str, Any] = field(default_factory=dict)
```
lightx2v/models/networks/wan/matrix_game2_model.py (new file, mode 100644)
```python
import json
import os

import torch
from safetensors import safe_open

from lightx2v.models.networks.wan.infer.matrix_game2.pre_infer import WanMtxg2PreInfer
from lightx2v.models.networks.wan.infer.matrix_game2.transformer_infer import WanMtxg2TransformerInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.sf_model import WanSFModel
from lightx2v.models.networks.wan.weights.matrix_game2.pre_weights import WanMtxg2PreWeights
from lightx2v.models.networks.wan.weights.matrix_game2.transformer_weights import WanActionTransformerWeights
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *


class WanSFMtxg2Model(WanSFModel):
    pre_weight_class = WanMtxg2PreWeights
    transformer_weight_class = WanActionTransformerWeights

    def __init__(self, model_path, config, device):
        super().__init__(model_path, config, device)

    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        with safe_open(file_path, framework="pt", device=str(self.device)) as f:
            return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())) for key in f.keys()}

    def _load_ckpt(self, unified_dtype, sensitive_layer):
        file_path = os.path.join(self.config["model_path"], f"{self.config['sub_model_folder']}/{self.config['sub_model_name']}")
        _weight_dict = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
        weight_dict = {}
        for k, v in _weight_dict.items():
            name = k[6:]
            weight = v.to(torch.bfloat16).to(self.device)
            weight_dict.update({name: weight})
        del _weight_dict
        return weight_dict

    def _init_infer_class(self):
        # update config by real model config
        with open(os.path.join(self.config["model_path"], self.config["sub_model_folder"], "config.json")) as f:
            model_config = json.load(f)
        for k in model_config.keys():
            self.config[k] = model_config[k]
        self.pre_infer_class = WanMtxg2PreInfer
        self.post_infer_class = WanPostInfer
        self.transformer_infer_class = WanMtxg2TransformerInfer
```
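`_load_ckpt` drops the first six characters of every checkpoint key before binding weights by name, which implies the distilled safetensors store keys under a six-character prefix such as `model.` (an assumption; the prefix itself is not visible in this diff):

```python
# Hypothetical key, illustrating what k[6:] assumes about the checkpoint layout.
raw_key = "model.patch_embedding.weight"
print(raw_key[6:])  # -> "patch_embedding.weight"
```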
lightx2v/models/networks/wan/sf_model.py
```diff
@@ -11,7 +11,8 @@ from lightx2v.models.networks.wan.model import WanModel
 class WanSFModel(WanModel):
     def __init__(self, model_path, config, device):
         super().__init__(model_path, config, device)
-        self.to_cuda()
+        if config["model_cls"] not in ["wan2.1_sf_mtxg2"]:
+            self.to_cuda()

     def _load_ckpt(self, unified_dtype, sensitive_layer):
         sf_confg = self.config["sf_config"]
```
lightx2v/models/networks/wan/weights/matrix_game2/pre_weights.py (new file, mode 100644)
```python
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import (
    CONV3D_WEIGHT_REGISTER,
    LN_WEIGHT_REGISTER,
    MM_WEIGHT_REGISTER,
)


class WanMtxg2PreWeights(WeightModule):
    def __init__(self, config):
        super().__init__()
        self.in_dim = config["in_dim"]
        self.dim = config["dim"]
        self.patch_size = (1, 2, 2)
        self.config = config

        # patch
        self.add_module(
            "patch_embedding",
            CONV3D_WEIGHT_REGISTER["Default"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size),
        )
        # time
        self.add_module(
            "time_embedding_0",
            MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias"),
        )
        self.add_module(
            "time_embedding_2",
            MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias"),
        )
        self.add_module(
            "time_projection_1",
            MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"),
        )
        # img_emb
        self.add_module(
            "img_emb_0",
            LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias", eps=1e-5),
        )
        self.add_module(
            "img_emb_1",
            MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias"),
        )
        self.add_module(
            "img_emb_3",
            MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias"),
        )
        self.add_module(
            "img_emb_4",
            LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias", eps=1e-5),
        )
```
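The `img_emb.proj.{0,1,3,4}` names bound here line up with `WanMtxg2PreInfer.img_emb` above (LayerNorm → Linear → GELU → Linear → LayerNorm, with the parameter-free GELU occupying index 2, hence no `img_emb.proj.2` tensors). A sketch of the implied original module; the 1280 → 1536 widths are assumptions based on the CLIP ViT width and `dim` in the configs:

```python
import torch.nn as nn

# Implied layout of the original img_emb projection (widths are assumptions):
proj = nn.Sequential(
    nn.LayerNorm(1280),     # img_emb.proj.0
    nn.Linear(1280, 1536),  # img_emb.proj.1
    nn.GELU(),              # img_emb.proj.2 carries no weights
    nn.Linear(1536, 1536),  # img_emb.proj.3
    nn.LayerNorm(1536),     # img_emb.proj.4
)
```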