xuwx1 / LightX2V · Commits

Commit 9826b8ca (Unverified)
[feat]: support matrix game2 universal, gta_drive, templerun & streaming mode
Authored Nov 13, 2025 by Watebear; committed by GitHub on Nov 13, 2025
Parent: 44e215f3

Showing 20 changed files with 2199 additions and 5 deletions (+2199, -5)
configs/matrix_game2/matrix_game2_gta_drive.json                      +72   -0
configs/matrix_game2/matrix_game2_gta_drive_streaming.json            +72   -0
configs/matrix_game2/matrix_game2_templerun.json                      +65   -0
configs/matrix_game2/matrix_game2_templerun_streaming.json            +72   -0
configs/matrix_game2/matrix_game2_universal.json                      +72   -0
configs/matrix_game2/matrix_game2_universal_streaming.json            +72   -0
lightx2v/common/ops/mm/mm_weight.py                                   +2    -2
lightx2v/infer.py                                                     +2    -0
lightx2v/models/input_encoders/hf/q_linear.py                         +2    -2
lightx2v/models/input_encoders/hf/wan/matrix_game2/__init__.py        +0    -0
lightx2v/models/input_encoders/hf/wan/matrix_game2/clip.py            +332  -0
lightx2v/models/input_encoders/hf/wan/matrix_game2/conditions.py      +203  -0
lightx2v/models/input_encoders/hf/wan/matrix_game2/tokenizers.py      +75   -0
lightx2v/models/networks/wan/infer/matrix_game2/posemb_layers.py      +291  -0
lightx2v/models/networks/wan/infer/matrix_game2/pre_infer.py          +98   -0
lightx2v/models/networks/wan/infer/matrix_game2/transformer_infer.py  +668  -0
lightx2v/models/networks/wan/infer/module_io.py                       +1    -0
lightx2v/models/networks/wan/matrix_game2_model.py                    +48   -0
lightx2v/models/networks/wan/sf_model.py                              +2    -1
lightx2v/models/networks/wan/weights/matrix_game2/pre_weights.py      +50   -0
configs/matrix_game2/matrix_game2_gta_drive.json (new file, mode 100644)

```json
{
    "infer_steps": 50,
    "target_video_length": 150,
    "num_output_frames": 150,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "flash_attn2",
    "cross_attn_1_type": "flash_attn2",
    "cross_attn_2_type": "flash_attn2",
    "seed": 0,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "enable_cfg": false,
    "cpu_offload": false,
    "sf_config": {
        "local_attn_size": 6,
        "shift": 5.0,
        "num_frame_per_block": 3,
        "num_transformer_blocks": 30,
        "frame_seq_length": 880,
        "num_output_frames": 150,
        "num_inference_steps": 1000,
        "denoising_step_list": [1000.0000, 908.8427, 713.9794]
    },
    "sub_model_folder": "gta_distilled_model",
    "sub_model_name": "gta_keyboard2dim.safetensors",
    "mode": "gta_drive",
    "streaming": false,
    "action_config": {
        "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
        "enable_keyboard": true,
        "enable_mouse": true,
        "heads_num": 16,
        "hidden_size": 128,
        "img_hidden_size": 1536,
        "keyboard_dim_in": 4,
        "keyboard_hidden_dim": 1024,
        "mouse_dim_in": 2,
        "mouse_hidden_dim": 1024,
        "mouse_qk_dim_list": [8, 28, 28],
        "patch_size": [1, 2, 2],
        "qk_norm": true,
        "qkv_bias": false,
        "rope_dim_list": [8, 28, 28],
        "rope_theta": 256,
        "vae_time_compression_ratio": 4,
        "windows_size": 3
    },
    "dim": 1536,
    "eps": 1e-06,
    "ffn_dim": 8960,
    "freq_dim": 256,
    "in_dim": 36,
    "inject_sample_info": false,
    "model_type": "i2v",
    "num_heads": 12,
    "num_layers": 30,
    "out_dim": 16
}
```
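The five configs that follow share this schema; the streaming variants mainly change the clip length (360 rather than 150 frames) and the `streaming` flag. As a minimal sketch (plain Python, no LightX2V imports), one can load a config and check a few invariants the schema appears to rely on:

```python
import json

# Load the GTA-drive config added in this commit.
with open("configs/matrix_game2/matrix_game2_gta_drive.json") as f:
    cfg = json.load(f)

# Invariants that hold across all six matrix_game2 configs in this commit:
# output length is declared consistently in three places, and the
# self-forcing block count matches the transformer depth.
assert cfg["target_video_length"] == cfg["num_output_frames"] == cfg["sf_config"]["num_output_frames"]
assert cfg["num_layers"] == cfg["sf_config"]["num_transformer_blocks"] == 30

# Three entries in denoising_step_list -> a few-step distilled sampler.
print(len(cfg["sf_config"]["denoising_step_list"]), cfg["mode"], cfg["streaming"])
```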
configs/matrix_game2/matrix_game2_gta_drive_streaming.json (new file, mode 100644)

```json
{
    "infer_steps": 50,
    "target_video_length": 360,
    "num_output_frames": 360,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "flash_attn2",
    "cross_attn_1_type": "flash_attn2",
    "cross_attn_2_type": "flash_attn2",
    "seed": 0,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "enable_cfg": false,
    "cpu_offload": false,
    "sf_config": {
        "local_attn_size": 6,
        "shift": 5.0,
        "num_frame_per_block": 3,
        "num_transformer_blocks": 30,
        "frame_seq_length": 880,
        "num_output_frames": 360,
        "num_inference_steps": 1000,
        "denoising_step_list": [1000.0000, 908.8427, 713.9794]
    },
    "sub_model_folder": "gta_distilled_model",
    "sub_model_name": "gta_keyboard2dim.safetensors",
    "mode": "gta_drive",
    "streaming": true,
    "action_config": {
        "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
        "enable_keyboard": true,
        "enable_mouse": true,
        "heads_num": 16,
        "hidden_size": 128,
        "img_hidden_size": 1536,
        "keyboard_dim_in": 4,
        "keyboard_hidden_dim": 1024,
        "mouse_dim_in": 2,
        "mouse_hidden_dim": 1024,
        "mouse_qk_dim_list": [8, 28, 28],
        "patch_size": [1, 2, 2],
        "qk_norm": true,
        "qkv_bias": false,
        "rope_dim_list": [8, 28, 28],
        "rope_theta": 256,
        "vae_time_compression_ratio": 4,
        "windows_size": 3
    },
    "dim": 1536,
    "eps": 1e-06,
    "ffn_dim": 8960,
    "freq_dim": 256,
    "in_dim": 36,
    "inject_sample_info": false,
    "model_type": "i2v",
    "num_heads": 12,
    "num_layers": 30,
    "out_dim": 16
}
```
configs/matrix_game2/matrix_game2_templerun.json (new file, mode 100644)

```json
{
    "infer_steps": 50,
    "target_video_length": 150,
    "num_output_frames": 150,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "flash_attn2",
    "cross_attn_1_type": "flash_attn2",
    "cross_attn_2_type": "flash_attn2",
    "seed": 0,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "enable_cfg": false,
    "cpu_offload": false,
    "sf_config": {
        "local_attn_size": 6,
        "shift": 5.0,
        "num_frame_per_block": 3,
        "num_transformer_blocks": 30,
        "frame_seq_length": 880,
        "num_output_frames": 150,
        "num_inference_steps": 1000,
        "denoising_step_list": [1000.0000, 908.8427, 713.9794]
    },
    "sub_model_folder": "templerun_distilled_model",
    "sub_model_name": "templerun_7dim_onlykey.safetensors",
    "mode": "templerun",
    "streaming": false,
    "action_config": {
        "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
        "enable_keyboard": true,
        "enable_mouse": false,
        "heads_num": 16,
        "hidden_size": 128,
        "img_hidden_size": 1536,
        "keyboard_dim_in": 7,
        "keyboard_hidden_dim": 1024,
        "patch_size": [1, 2, 2],
        "qk_norm": true,
        "qkv_bias": false,
        "rope_dim_list": [8, 28, 28],
        "rope_theta": 256,
        "vae_time_compression_ratio": 4,
        "windows_size": 3
    },
    "dim": 1536,
    "eps": 1e-06,
    "ffn_dim": 8960,
    "freq_dim": 256,
    "in_dim": 36,
    "inject_sample_info": false,
    "model_type": "i2v",
    "num_heads": 12,
    "num_layers": 30,
    "out_dim": 16
}
```
configs/matrix_game2/matrix_game2_templerun_streaming.json (new file, mode 100644)

```json
{
    "infer_steps": 50,
    "target_video_length": 360,
    "num_output_frames": 360,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "flash_attn2",
    "cross_attn_1_type": "flash_attn2",
    "cross_attn_2_type": "flash_attn2",
    "seed": 0,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "enable_cfg": false,
    "cpu_offload": false,
    "sf_config": {
        "local_attn_size": 6,
        "shift": 5.0,
        "num_frame_per_block": 3,
        "num_transformer_blocks": 30,
        "frame_seq_length": 880,
        "num_output_frames": 360,
        "num_inference_steps": 1000,
        "denoising_step_list": [1000.0000, 908.8427, 713.9794]
    },
    "sub_model_folder": "templerun_distilled_model",
    "sub_model_name": "templerun_7dim_onlykey.safetensors",
    "mode": "templerun",
    "streaming": true,
    "action_config": {
        "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
        "enable_keyboard": true,
        "enable_mouse": true,
        "heads_num": 16,
        "hidden_size": 128,
        "img_hidden_size": 1536,
        "keyboard_dim_in": 4,
        "keyboard_hidden_dim": 1024,
        "mouse_dim_in": 2,
        "mouse_hidden_dim": 1024,
        "mouse_qk_dim_list": [8, 28, 28],
        "patch_size": [1, 2, 2],
        "qk_norm": true,
        "qkv_bias": false,
        "rope_dim_list": [8, 28, 28],
        "rope_theta": 256,
        "vae_time_compression_ratio": 4,
        "windows_size": 3
    },
    "dim": 1536,
    "eps": 1e-06,
    "ffn_dim": 8960,
    "freq_dim": 256,
    "in_dim": 36,
    "inject_sample_info": false,
    "model_type": "i2v",
    "num_heads": 12,
    "num_layers": 30,
    "out_dim": 16
}
```
configs/matrix_game2/matrix_game2_universal.json (new file, mode 100644)

```json
{
    "infer_steps": 50,
    "target_video_length": 150,
    "num_output_frames": 150,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "flash_attn2",
    "cross_attn_1_type": "flash_attn2",
    "cross_attn_2_type": "flash_attn2",
    "seed": 0,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "enable_cfg": false,
    "cpu_offload": false,
    "sf_config": {
        "local_attn_size": 6,
        "shift": 5.0,
        "num_frame_per_block": 3,
        "num_transformer_blocks": 30,
        "frame_seq_length": 880,
        "num_output_frames": 150,
        "num_inference_steps": 1000,
        "denoising_step_list": [1000.0000, 908.8427, 713.9794]
    },
    "sub_model_folder": "base_distilled_model",
    "sub_model_name": "base_distill.safetensors",
    "mode": "universal",
    "streaming": false,
    "action_config": {
        "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
        "enable_keyboard": true,
        "enable_mouse": true,
        "heads_num": 16,
        "hidden_size": 128,
        "img_hidden_size": 1536,
        "keyboard_dim_in": 4,
        "keyboard_hidden_dim": 1024,
        "mouse_dim_in": 2,
        "mouse_hidden_dim": 1024,
        "mouse_qk_dim_list": [8, 28, 28],
        "patch_size": [1, 2, 2],
        "qk_norm": true,
        "qkv_bias": false,
        "rope_dim_list": [8, 28, 28],
        "rope_theta": 256,
        "vae_time_compression_ratio": 4,
        "windows_size": 3
    },
    "dim": 1536,
    "eps": 1e-06,
    "ffn_dim": 8960,
    "freq_dim": 256,
    "in_dim": 36,
    "inject_sample_info": false,
    "model_type": "i2v",
    "num_heads": 12,
    "num_layers": 30,
    "out_dim": 16
}
```
configs/matrix_game2/matrix_game2_universal_streaming.json (new file, mode 100644)

```json
{
    "infer_steps": 50,
    "target_video_length": 360,
    "num_output_frames": 360,
    "text_len": 512,
    "target_height": 480,
    "target_width": 832,
    "self_attn_1_type": "flash_attn2",
    "cross_attn_1_type": "flash_attn2",
    "cross_attn_2_type": "flash_attn2",
    "seed": 0,
    "sample_guide_scale": 6,
    "sample_shift": 8,
    "enable_cfg": false,
    "cpu_offload": false,
    "sf_config": {
        "local_attn_size": 6,
        "shift": 5.0,
        "num_frame_per_block": 3,
        "num_transformer_blocks": 30,
        "frame_seq_length": 880,
        "num_output_frames": 360,
        "num_inference_steps": 1000,
        "denoising_step_list": [1000.0000, 908.8427, 713.9794]
    },
    "sub_model_folder": "base_distilled_model",
    "sub_model_name": "base_distill.safetensors",
    "mode": "universal",
    "streaming": true,
    "action_config": {
        "blocks": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14],
        "enable_keyboard": true,
        "enable_mouse": true,
        "heads_num": 16,
        "hidden_size": 128,
        "img_hidden_size": 1536,
        "keyboard_dim_in": 4,
        "keyboard_hidden_dim": 1024,
        "mouse_dim_in": 2,
        "mouse_hidden_dim": 1024,
        "mouse_qk_dim_list": [8, 28, 28],
        "patch_size": [1, 2, 2],
        "qk_norm": true,
        "qkv_bias": false,
        "rope_dim_list": [8, 28, 28],
        "rope_theta": 256,
        "vae_time_compression_ratio": 4,
        "windows_size": 3
    },
    "dim": 1536,
    "eps": 1e-06,
    "ffn_dim": 8960,
    "freq_dim": 256,
    "in_dim": 36,
    "inject_sample_info": false,
    "model_type": "i2v",
    "num_heads": 12,
    "num_layers": 30,
    "out_dim": 16
}
```
lightx2v/common/ops/mm/mm_weight.py (modified)

```diff
@@ -51,7 +51,7 @@ except ImportError:
 try:
     from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
-except ModuleNotFoundError:
+except ImportError:
     quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
 try:
@@ -61,7 +61,7 @@ except ImportError:
 try:
     import marlin_cuda_quant
-except ModuleNotFoundError:
+except ImportError:
     marlin_cuda_quant = None
```
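The switch from `ModuleNotFoundError` to `ImportError` widens the guard: `ModuleNotFoundError` is a strict subclass of `ImportError`, so the new clause catches everything the old one did, plus cases where the optional module exists but fails while being loaded (a common failure mode for CUDA extensions). A small sketch of the difference:

```python
# ModuleNotFoundError subclasses ImportError (Python 3.6+), so the new
# clause is a superset of the old one.
assert issubclass(ModuleNotFoundError, ImportError)

try:
    # An optional CUDA extension may be installed yet still raise a plain
    # ImportError at load time (e.g. missing shared library symbols).
    import marlin_cuda_quant
except ImportError:
    marlin_cuda_quant = None
```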
lightx2v/infer.py (modified)

```diff
@@ -9,6 +9,7 @@ from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner
 from lightx2v.models.runners.wan.wan_animate_runner import WanAnimateRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_distill_runner import WanDistillRunner  # noqa: F401
+from lightx2v.models.runners.wan.wan_matrix_game2_runner import WanSFMtxg2Runner  # noqa: F401
 from lightx2v.models.runners.wan.wan_runner import Wan22MoeRunner, WanRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_sf_runner import WanSFRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_vace_runner import WanVaceRunner  # noqa: F401
@@ -39,6 +40,7 @@ def main():
             "wan2.1_distill",
             "wan2.1_vace",
             "wan2.1_sf",
+            "wan2.1_sf_mtxg2",
             "seko_talk",
             "wan2.2_moe",
             "wan2.2",
```
lightx2v/models/input_encoders/hf/q_linear.py (modified)

```diff
@@ -3,7 +3,7 @@ import torch.nn as nn
 try:
     from vllm import _custom_ops as ops
-except ModuleNotFoundError:
+except ImportError:
     ops = None
 try:
@@ -13,7 +13,7 @@ except ImportError:
 try:
     from torchao.quantization.utils import quant_int8_per_token_matmul, quantize_activation_per_token_absmax
-except ModuleNotFoundError:
+except ImportError:
     quant_int8_per_token_matmul, quantize_activation_per_token_absmax = None, None
 try:
```
lightx2v/models/input_encoders/hf/wan/matrix_game2/__init__.py (new empty file, mode 100644)
lightx2v/models/input_encoders/hf/wan/matrix_game2/clip.py (new file, mode 100644)

```python
# Modified from ``https://github.com/openai/CLIP'' and ``https://github.com/mlfoundations/open_clip''
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from diffusers.models import ModelMixin

from lightx2v.models.input_encoders.hf.wan.matrix_game2.tokenizers import HuggingfaceTokenizer
from lightx2v.models.input_encoders.hf.wan.xlm_roberta.model import VisionTransformer


class SelfAttention(nn.Module):
    def __init__(self, dim, num_heads, dropout=0.1, eps=1e-5):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.eps = eps

        # layers
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """
        x: [B, L, C].
        """
        b, s, c, n, d = *x.size(), self.num_heads, self.head_dim

        # compute query, key, value
        q = self.q(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        k = self.k(x).reshape(b, s, n, d).permute(0, 2, 1, 3)
        v = self.v(x).reshape(b, s, n, d).permute(0, 2, 1, 3)

        # compute attention
        p = self.dropout.p if self.training else 0.0
        x = F.scaled_dot_product_attention(q, k, v, mask, p)
        x = x.permute(0, 2, 1, 3).reshape(b, s, c)

        # output
        x = self.o(x)
        x = self.dropout(x)
        return x


class AttentionBlock(nn.Module):
    def __init__(self, dim, num_heads, post_norm, dropout=0.1, eps=1e-5):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.post_norm = post_norm
        self.eps = eps

        # layers
        self.attn = SelfAttention(dim, num_heads, dropout, eps)
        self.norm1 = nn.LayerNorm(dim, eps=eps)
        self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim), nn.Dropout(dropout))
        self.norm2 = nn.LayerNorm(dim, eps=eps)

    def forward(self, x, mask):
        if self.post_norm:
            x = self.norm1(x + self.attn(x, mask))
            x = self.norm2(x + self.ffn(x))
        else:
            x = x + self.attn(self.norm1(x), mask)
            x = x + self.ffn(self.norm2(x))
        return x


class XLMRoberta(nn.Module):
    """
    XLMRobertaModel with no pooler and no LM head.
    """

    def __init__(self, vocab_size=250002, max_seq_len=514, type_size=1, pad_id=1, dim=1024, num_heads=16, num_layers=24, post_norm=True, dropout=0.1, eps=1e-5):
        super().__init__()
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.dim = dim
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.post_norm = post_norm
        self.eps = eps

        # embeddings
        self.token_embedding = nn.Embedding(vocab_size, dim, padding_idx=pad_id)
        self.type_embedding = nn.Embedding(type_size, dim)
        self.pos_embedding = nn.Embedding(max_seq_len, dim, padding_idx=pad_id)
        self.dropout = nn.Dropout(dropout)

        # blocks
        self.blocks = nn.ModuleList([AttentionBlock(dim, num_heads, post_norm, dropout, eps) for _ in range(num_layers)])

        # norm layer
        self.norm = nn.LayerNorm(dim, eps=eps)

    def forward(self, ids):
        """
        ids: [B, L] of torch.LongTensor.
        """
        b, s = ids.shape
        mask = ids.ne(self.pad_id).long()

        # embeddings
        x = self.token_embedding(ids) + self.type_embedding(torch.zeros_like(ids)) + self.pos_embedding(self.pad_id + torch.cumsum(mask, dim=1) * mask)
        if self.post_norm:
            x = self.norm(x)
        x = self.dropout(x)

        # blocks
        mask = torch.where(mask.view(b, 1, 1, s).gt(0), 0.0, torch.finfo(x.dtype).min)
        for block in self.blocks:
            x = block(x, mask)

        # output
        if not self.post_norm:
            x = self.norm(x)
        return x


class XLMRobertaWithHead(XLMRoberta):
    def __init__(self, **kwargs):
        self.out_dim = kwargs.pop("out_dim")
        super().__init__(**kwargs)

        # head
        mid_dim = (self.dim + self.out_dim) // 2
        self.head = nn.Sequential(nn.Linear(self.dim, mid_dim, bias=False), nn.GELU(), nn.Linear(mid_dim, self.out_dim, bias=False))

    def forward(self, ids):
        # xlm-roberta
        x = super().forward(ids)

        # average pooling
        mask = ids.ne(self.pad_id).unsqueeze(-1).to(x)
        x = (x * mask).sum(dim=1) / mask.sum(dim=1)

        # head
        x = self.head(x)
        return x


class XLMRobertaCLIP(nn.Module):
    def __init__(
        self,
        dtype=torch.float16,
        embed_dim=1024,
        image_size=224,
        patch_size=14,
        vision_dim=1280,
        vision_mlp_ratio=4,
        vision_heads=16,
        vision_layers=32,
        vision_pool="token",
        vision_pre_norm=True,
        vision_post_norm=False,
        activation="gelu",
        vocab_size=250002,
        max_text_len=514,
        type_size=1,
        pad_id=1,
        attn_dropout=0.0,
        proj_dropout=0.0,
        embedding_dropout=0.0,
        norm_eps=1e-5,
        quantized=False,
        quant_scheme=None,
        text_dim=1024,
        text_heads=16,
        text_layers=24,
        text_post_norm=True,
        text_dropout=0.1,
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.image_size = image_size
        self.patch_size = patch_size
        self.vision_dim = vision_dim
        self.vision_mlp_ratio = vision_mlp_ratio
        self.vision_heads = vision_heads
        self.vision_layers = vision_layers
        self.vision_pre_norm = vision_pre_norm
        self.vision_post_norm = vision_post_norm
        self.activation = activation
        self.vocab_size = vocab_size
        self.max_text_len = max_text_len
        self.type_size = type_size
        self.pad_id = pad_id
        self.norm_eps = norm_eps

        # models
        self.visual = VisionTransformer(
            dtype=dtype,
            image_size=image_size,
            patch_size=patch_size,
            dim=vision_dim,
            mlp_ratio=vision_mlp_ratio,
            out_dim=embed_dim,
            num_heads=vision_heads,
            num_layers=vision_layers,
            pool_type=vision_pool,
            pre_norm=vision_pre_norm,
            post_norm=vision_post_norm,
            activation=activation,
            attn_dropout=attn_dropout,
            proj_dropout=proj_dropout,
            embedding_dropout=embedding_dropout,
            norm_eps=norm_eps,
            quantized=quantized,
            quant_scheme=quant_scheme,
        )
        self.textual = XLMRobertaWithHead(
            vocab_size=vocab_size,
            max_seq_len=max_text_len,
            type_size=type_size,
            pad_id=pad_id,
            dim=text_dim,
            out_dim=embed_dim,
            num_heads=text_heads,
            num_layers=text_layers,
            post_norm=text_post_norm,
            dropout=text_dropout,
        )
        self.log_scale = nn.Parameter(math.log(1 / 0.07) * torch.ones([]))


def _clip(
    pretrained=False,
    pretrained_name=None,
    model_cls=XLMRobertaCLIP,
    return_transforms=False,
    return_tokenizer=False,
    tokenizer_padding="eos",
    dtype=torch.float32,
    device="cpu",
    **kwargs,
):
    # init a model on device
    with torch.device(device):
        model = model_cls(**kwargs)

    # set device
    model = model.to(dtype=dtype, device=device)
    output = (model,)

    # init transforms
    if return_transforms:
        # mean and std
        if "siglip" in pretrained_name.lower():
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            mean = [0.48145466, 0.4578275, 0.40821073]
            std = [0.26862954, 0.26130258, 0.27577711]

        # transforms
        transforms = T.Compose([T.Resize((model.image_size, model.image_size), interpolation=T.InterpolationMode.BICUBIC), T.ToTensor(), T.Normalize(mean=mean, std=std)])
        output += (transforms,)
    return output[0] if len(output) == 1 else output


def clip_xlm_roberta_vit_h_14(pretrained=False, pretrained_name="open-clip-xlm-roberta-large-vit-huge-14", **kwargs):
    cfg = dict(
        embed_dim=1024,
        image_size=224,
        patch_size=14,
        vision_dim=1280,
        vision_mlp_ratio=4,
        vision_heads=16,
        vision_layers=32,
        vision_pool="token",
        activation="gelu",
        vocab_size=250002,
        max_text_len=514,
        type_size=1,
        pad_id=1,
        text_dim=1024,
        text_heads=16,
        text_layers=24,
        text_post_norm=True,
        text_dropout=0.1,
        attn_dropout=0.0,
        proj_dropout=0.0,
        embedding_dropout=0.0,
    )
    cfg.update(**kwargs)
    return _clip(pretrained, pretrained_name, XLMRobertaCLIP, **cfg)


class CLIPModel(ModelMixin):
    def __init__(self, checkpoint_path, tokenizer_path):
        super().__init__()
        self.checkpoint_path = checkpoint_path
        self.tokenizer_path = tokenizer_path

        # init model
        self.model, self.transforms = clip_xlm_roberta_vit_h_14(
            pretrained=False,
            return_transforms=True,
            return_tokenizer=False,
        )
        self.model = self.model.eval().requires_grad_(False)
        logging.info(f"loading {checkpoint_path}")
        self.model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

        # init tokenizer
        self.tokenizer = HuggingfaceTokenizer(name=tokenizer_path, seq_len=self.model.max_text_len - 2, clean="whitespace")

    def encode_video(self, video):
        # preprocess
        b, c, t, h, w = video.shape
        video = video.transpose(1, 2)
        video = video.reshape(b * t, c, h, w)
        size = (self.model.image_size,) * 2
        video = F.interpolate(video, size=size, mode="bicubic", align_corners=False)
        video = self.transforms.transforms[-1](video.mul_(0.5).add_(0.5))

        # forward
        with torch.amp.autocast(dtype=self.dtype, device_type=self.device.type):
            out = self.model.visual(video, use_31_block=True)
            return out

    def forward(self, videos):
        # preprocess
        size = (self.model.image_size,) * 2
        videos = torch.cat([F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos])
        videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

        # forward
        with torch.amp.autocast("cuda", dtype=self.dtype):
            out = self.model.visual(videos, use_31_block=True)
            return out
```
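A minimal sketch of the preprocessing path in `CLIPModel.encode_video`, run on dummy data so no checkpoint is needed; the 224x224 bicubic resize and the CLIP mean/std mirror the code above:

```python
import torch
import torch.nn.functional as F
import torchvision.transforms as T

# Dummy video in [-1, 1], shaped [B, C, T, H, W] as encode_video expects.
b, c, t, h, w = 1, 3, 4, 480, 832
video = torch.rand(b, c, t, h, w) * 2 - 1

# Same flattening + resize as encode_video (image_size=224 for ViT-H/14).
frames = video.transpose(1, 2).reshape(b * t, c, h, w)
frames = F.interpolate(frames, size=(224, 224), mode="bicubic", align_corners=False)

# Map [-1, 1] -> [0, 1], then apply the CLIP normalization used above.
normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
frames = normalize(frames.mul_(0.5).add_(0.5))
print(frames.shape)  # torch.Size([4, 3, 224, 224])
```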
lightx2v/models/input_encoders/hf/wan/matrix_game2/conditions.py (new file, mode 100644)

```python
import random

import torch


def combine_data(data, num_frames=57, keyboard_dim=6, mouse=True):
    assert num_frames % 4 == 1
    keyboard_condition = torch.zeros((num_frames, keyboard_dim))
    if mouse:
        mouse_condition = torch.zeros((num_frames, 2))
    current_frame = 0
    selections = [12]
    while current_frame < num_frames:
        rd_frame = selections[random.randint(0, len(selections) - 1)]
        rd = random.randint(0, len(data) - 1)
        k = data[rd]["keyboard_condition"]
        if mouse:
            m = data[rd]["mouse_condition"]
        if current_frame == 0:
            keyboard_condition[:1] = k[:1]
            if mouse:
                mouse_condition[:1] = m[:1]
            current_frame = 1
        else:
            rd_frame = min(rd_frame, num_frames - current_frame)
            repeat_time = rd_frame // 4
            keyboard_condition[current_frame : current_frame + rd_frame] = k.repeat(repeat_time, 1)
            if mouse:
                mouse_condition[current_frame : current_frame + rd_frame] = m.repeat(repeat_time, 1)
            current_frame += rd_frame
    if mouse:
        return {"keyboard_condition": keyboard_condition, "mouse_condition": mouse_condition}
    return {"keyboard_condition": keyboard_condition}


def Bench_actions_universal(num_frames, num_samples_per_action=4):
    actions_single_action = [
        "forward",
        # "back",
        "left",
        "right",
    ]
    actions_double_action = [
        "forward_left",
        "forward_right",
        # "back_left",
        # "back_right",
    ]
    actions_single_camera = [
        "camera_l",
        "camera_r",
        # "camera_ur",
        # "camera_ul",
        # "camera_dl",
        # "camera_dr"
        # "camera_up",
        # "camera_down",
    ]
    actions_to_test = actions_double_action * 5 + actions_single_camera * 5 + actions_single_action * 5
    for action in actions_single_action + actions_double_action:
        for camera in actions_single_camera:
            double_action = f"{action}_{camera}"
            actions_to_test.append(double_action)
    # print("length of actions: ", len(actions_to_test))
    base_action = actions_single_action + actions_single_camera
    KEYBOARD_IDX = {"forward": 0, "back": 1, "left": 2, "right": 3}
    CAM_VALUE = 0.1
    CAMERA_VALUE_MAP = {
        "camera_up": [CAM_VALUE, 0],
        "camera_down": [-CAM_VALUE, 0],
        "camera_l": [0, -CAM_VALUE],
        "camera_r": [0, CAM_VALUE],
        "camera_ur": [CAM_VALUE, CAM_VALUE],
        "camera_ul": [CAM_VALUE, -CAM_VALUE],
        "camera_dr": [-CAM_VALUE, CAM_VALUE],
        "camera_dl": [-CAM_VALUE, -CAM_VALUE],
    }
    data = []
    for action_name in actions_to_test:
        keyboard_condition = [[0, 0, 0, 0] for _ in range(num_samples_per_action)]
        mouse_condition = [[0, 0] for _ in range(num_samples_per_action)]
        for sub_act in base_action:
            if sub_act not in action_name:
                # only handle sub-actions contained in action_name
                continue
            # print(f"action name: {action_name} sub_act: {sub_act}")
            if sub_act in CAMERA_VALUE_MAP:
                mouse_condition = [CAMERA_VALUE_MAP[sub_act] for _ in range(num_samples_per_action)]
            elif sub_act in KEYBOARD_IDX:
                col = KEYBOARD_IDX[sub_act]
                for row in keyboard_condition:
                    row[col] = 1
        data.append({"keyboard_condition": torch.tensor(keyboard_condition), "mouse_condition": torch.tensor(mouse_condition)})
    return combine_data(data, num_frames, keyboard_dim=4, mouse=True)


def Bench_actions_gta_drive(num_frames, num_samples_per_action=4):
    actions_single_action = [
        "forward",
        "back",
    ]
    actions_single_camera = [
        "camera_l",
        "camera_r",
    ]
    actions_to_test = actions_single_camera * 2 + actions_single_action * 2
    for action in actions_single_action:
        for camera in actions_single_camera:
            double_action = f"{action}_{camera}"
            actions_to_test.append(double_action)
    # print("length of actions: ", len(actions_to_test))
    base_action = actions_single_action + actions_single_camera
    KEYBOARD_IDX = {"forward": 0, "back": 1}
    CAM_VALUE = 0.1
    CAMERA_VALUE_MAP = {
        "camera_l": [0, -CAM_VALUE],
        "camera_r": [0, CAM_VALUE],
    }
    data = []
    for action_name in actions_to_test:
        keyboard_condition = [[0, 0] for _ in range(num_samples_per_action)]
        mouse_condition = [[0, 0] for _ in range(num_samples_per_action)]
        for sub_act in base_action:
            if sub_act not in action_name:
                # only handle sub-actions contained in action_name
                continue
            # print(f"action name: {action_name} sub_act: {sub_act}")
            if sub_act in CAMERA_VALUE_MAP:
                mouse_condition = [CAMERA_VALUE_MAP[sub_act] for _ in range(num_samples_per_action)]
            elif sub_act in KEYBOARD_IDX:
                col = KEYBOARD_IDX[sub_act]
                for row in keyboard_condition:
                    row[col] = 1
        data.append({"keyboard_condition": torch.tensor(keyboard_condition), "mouse_condition": torch.tensor(mouse_condition)})
    return combine_data(data, num_frames, keyboard_dim=2, mouse=True)


def Bench_actions_templerun(num_frames, num_samples_per_action=4):
    actions_single_action = ["jump", "slide", "leftside", "rightside", "turnleft", "turnright", "nomove"]
    actions_to_test = actions_single_action
    base_action = actions_single_action
    KEYBOARD_IDX = {"nomove": 0, "jump": 1, "slide": 2, "turnleft": 3, "turnright": 4, "leftside": 5, "rightside": 6}
    data = []
    for action_name in actions_to_test:
        keyboard_condition = [[0, 0, 0, 0, 0, 0, 0] for _ in range(num_samples_per_action)]
        for sub_act in base_action:
            if sub_act not in action_name:
                # only handle sub-actions contained in action_name
                continue
            # print(f"action name: {action_name} sub_act: {sub_act}")
            elif sub_act in KEYBOARD_IDX:
                col = KEYBOARD_IDX[sub_act]
                for row in keyboard_condition:
                    row[col] = 1
        data.append({"keyboard_condition": torch.tensor(keyboard_condition)})
    return combine_data(data, num_frames, keyboard_dim=7, mouse=False)


class MatrixGame2_Bench:
    def __init__(self):
        self.device = torch.device("cuda")
        self.weight_dtype = torch.bfloat16

    def get_conditions(self, mode, num_frames):
        conditional_dict = {}
        if mode == "universal":
            cond_data = Bench_actions_universal(num_frames)
            mouse_condition = cond_data["mouse_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
            conditional_dict["mouse_cond"] = mouse_condition
        elif mode == "gta_drive":
            cond_data = Bench_actions_gta_drive(num_frames)
            mouse_condition = cond_data["mouse_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
            conditional_dict["mouse_cond"] = mouse_condition
        else:
            cond_data = Bench_actions_templerun(num_frames)
        keyboard_condition = cond_data["keyboard_condition"].unsqueeze(0).to(device=self.device, dtype=self.weight_dtype)
        conditional_dict["keyboard_cond"] = keyboard_condition
        return conditional_dict
```
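A quick usage sketch of the benchmark generators above. `num_frames` must satisfy `num_frames % 4 == 1` (asserted in `combine_data`), and the keyboard width matches each mode's `keyboard_dim_in`:

```python
from lightx2v.models.input_encoders.hf.wan.matrix_game2.conditions import (
    Bench_actions_templerun,
    Bench_actions_universal,
)

# 57 frames satisfies num_frames % 4 == 1.
cond = Bench_actions_universal(57)
print(cond["keyboard_condition"].shape)  # torch.Size([57, 4])
print(cond["mouse_condition"].shape)     # torch.Size([57, 2])

cond = Bench_actions_templerun(57)       # keyboard-only, 7-way one-hot
print(cond["keyboard_condition"].shape)  # torch.Size([57, 7])
```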
lightx2v/models/input_encoders/hf/wan/matrix_game2/tokenizers.py (new file, mode 100644)

```python
# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import html
import string

import ftfy
import regex as re
from transformers import AutoTokenizer

__all__ = ["HuggingfaceTokenizer"]


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def canonicalize(text, keep_punctuation_exact_string=None):
    text = text.replace("_", " ")
    if keep_punctuation_exact_string:
        text = keep_punctuation_exact_string.join(part.translate(str.maketrans("", "", string.punctuation)) for part in text.split(keep_punctuation_exact_string))
    else:
        text = text.translate(str.maketrans("", "", string.punctuation))
    text = text.lower()
    text = re.sub(r"\s+", " ", text)
    return text.strip()


class HuggingfaceTokenizer:
    def __init__(self, name, seq_len=None, clean=None, **kwargs):
        assert clean in (None, "whitespace", "lower", "canonicalize")
        self.name = name
        self.seq_len = seq_len
        self.clean = clean

        # init tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
        self.vocab_size = self.tokenizer.vocab_size

    def __call__(self, sequence, **kwargs):
        return_mask = kwargs.pop("return_mask", False)

        # arguments
        _kwargs = {"return_tensors": "pt"}
        if self.seq_len is not None:
            _kwargs.update({"padding": "max_length", "truncation": True, "max_length": self.seq_len})
        _kwargs.update(**kwargs)

        # tokenization
        if isinstance(sequence, str):
            sequence = [sequence]
        if self.clean:
            sequence = [self._clean(u) for u in sequence]
        ids = self.tokenizer(sequence, **_kwargs)

        # output
        if return_mask:
            return ids.input_ids, ids.attention_mask
        else:
            return ids.input_ids

    def _clean(self, text):
        if self.clean == "whitespace":
            text = whitespace_clean(basic_clean(text))
        elif self.clean == "lower":
            text = whitespace_clean(basic_clean(text)).lower()
        elif self.clean == "canonicalize":
            text = canonicalize(basic_clean(text))
        return text
```
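A minimal usage sketch; `clean="whitespace"` mirrors how `CLIPModel` above builds its tokenizer, and `"xlm-roberta-large"` is an illustrative Hugging Face tokenizer name, not one mandated by this commit:

```python
from lightx2v.models.input_encoders.hf.wan.matrix_game2.tokenizers import HuggingfaceTokenizer

# "xlm-roberta-large" is an assumed tokenizer checkpoint for illustration only.
tok = HuggingfaceTokenizer(name="xlm-roberta-large", seq_len=512, clean="whitespace")
ids, mask = tok("A car drives  forward.", return_mask=True)
print(ids.shape, mask.shape)  # both torch.Size([1, 512]) after max_length padding
```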
lightx2v/models/networks/wan/infer/matrix_game2/posemb_layers.py (new file, mode 100644)

```python
from typing import List, Tuple, Union

import torch


def _to_tuple(x, dim=2):
    if isinstance(x, int):
        return (x,) * dim
    elif len(x) == dim:
        return x
    else:
        raise ValueError(f"Expected length {dim} or int, but got {x}")


def get_meshgrid_nd(start, *args, dim=2):
    """
    Get n-D meshgrid with start, stop and num.

    Args:
        start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
            step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
            should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
            n-tuples.
        *args: See above.
        dim (int): Dimension of the meshgrid. Defaults to 2.

    Returns:
        grid (np.ndarray): [dim, ...]
    """
    if len(args) == 0:
        # start is grid_size
        num = _to_tuple(start, dim=dim)
        start = (0,) * dim
        stop = num
    elif len(args) == 1:
        # start is start, args[0] is stop, step is 1
        start = _to_tuple(start, dim=dim)
        stop = _to_tuple(args[0], dim=dim)
        num = [stop[i] - start[i] for i in range(dim)]
    elif len(args) == 2:
        # start is start, args[0] is stop, args[1] is num
        start = _to_tuple(start, dim=dim)  # Left-Top   eg: 12,0
        stop = _to_tuple(args[0], dim=dim)  # Right-Bottom   eg: 20,32
        num = _to_tuple(args[1], dim=dim)  # Target Size   eg: 32,124
    else:
        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")

    # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
    axis_grid = []
    for i in range(dim):
        a, b, n = start[i], stop[i], num[i]
        g = torch.linspace(a, b, n + 1, dtype=torch.float32, device=torch.cuda.current_device())[:n]
        axis_grid.append(g)
    grid = torch.meshgrid(*axis_grid, indexing="ij")  # dim x [W, H, D]
    grid = torch.stack(grid, dim=0)  # [dim, W, H, D]

    return grid


#################################################################################
#                   Rotary Positional Embedding Functions                       #
#################################################################################
# https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80


def reshape_for_broadcast(
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    x: torch.Tensor,
    head_first=False,
):
    """
    Reshape frequency tensor for broadcasting it with another tensor.

    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
    for the purpose of broadcasting the frequency tensor during element-wise operations.

    Notes:
        When using FlashMHAModified, head_first should be False.
        When using Attention, head_first should be True.

    Args:
        freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
        x (torch.Tensor): Target tensor for broadcasting compatibility.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        torch.Tensor: Reshaped frequency tensor.

    Raises:
        AssertionError: If the frequency tensor doesn't match the expected shape.
        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
    """
    ndim = x.ndim
    assert 0 <= 1 < ndim

    if isinstance(freqs_cis, tuple):
        # freqs_cis: (cos, sin) in real space
        if head_first:
            assert freqs_cis[0].shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            # assert freqs_cis[0].shape == (
            #     x.shape[1],
            #     x.shape[-1],
            # ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
            # shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
            shape = [1, freqs_cis[0].shape[0], 1, freqs_cis[0].shape[1]]
        return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
    else:
        # freqs_cis: values in complex space
        if head_first:
            assert freqs_cis.shape == (
                x.shape[-2],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        else:
            assert freqs_cis.shape == (
                x.shape[1],
                x.shape[-1],
            ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
            shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
        return freqs_cis.view(*shape)


def rotate_half(x):
    x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
    return torch.stack([-x_imag, x_real], dim=-1).flatten(3)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    head_first: bool = False,
    start_offset: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor.

    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
    returned as real tensors.

    Args:
        xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
        xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
        freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
        head_first (bool): head dimension first (except batch dim) or not.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
    """
    # print(freqs_cis[0].shape, xq.shape, xk.shape)
    xk_out = None
    assert isinstance(freqs_cis, tuple)
    if isinstance(freqs_cis, tuple):
        cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [S, D]
        cos, sin = cos.to(xq.device), sin.to(xq.device)
        # real * cos - imag * sin
        # imag * cos + real * sin
        xq_out = (xq.float() * cos[:, start_offset : start_offset + xq.shape[1], :, :] + rotate_half(xq.float()) * sin[:, start_offset : start_offset + xq.shape[1], :, :]).type_as(xq)
        xk_out = (xk.float() * cos[:, start_offset : start_offset + xk.shape[1], :, :] + rotate_half(xk.float()) * sin[:, start_offset : start_offset + xk.shape[1], :, :]).type_as(xk)
    else:
        # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))  # [B, S, H, D//2]
        freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device)  # [S, D//2] --> [1, S, 1, D//2]
        # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
        # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
        xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))  # [B, S, H, D//2]
        xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)

    return xq_out, xk_out


def get_nd_rotary_pos_embed(
    rope_dim_list,
    start,
    *args,
    theta=10000.0,
    use_real=False,
    theta_rescale_factor: Union[float, List[float]] = 1.0,
    interpolation_factor: Union[float, List[float]] = 1.0,
):
    """
    This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.

    Args:
        rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
            sum(rope_dim_list) should equal to head_dim of attention layer.
        start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
            args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
        *args: See above.
        theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
            Some libraries such as TensorRT does not support complex64 data type. So it is useful to provide a real
            part and an imaginary part separately.
        theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.

    Returns:
        pos_embed (torch.Tensor): [HW, D/2]
    """
    grid = get_meshgrid_nd(start, *args, dim=len(rope_dim_list))  # [3, W, H, D] / [2, W, H]

    if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
        theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
    elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
        theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
    assert len(theta_rescale_factor) == len(rope_dim_list), "len(theta_rescale_factor) should equal to len(rope_dim_list)"

    if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
        interpolation_factor = [interpolation_factor] * len(rope_dim_list)
    elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
        interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
    assert len(interpolation_factor) == len(rope_dim_list), "len(interpolation_factor) should equal to len(rope_dim_list)"

    # use 1/ndim of dimensions to encode grid_axis
    embs = []
    for i in range(len(rope_dim_list)):
        emb = get_1d_rotary_pos_embed(
            rope_dim_list[i],
            grid[i].reshape(-1),
            theta,
            use_real=use_real,
            theta_rescale_factor=theta_rescale_factor[i],
            interpolation_factor=interpolation_factor[i],
        )  # 2 x [WHD, rope_dim_list[i]]
        embs.append(emb)

    if use_real:
        cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D/2)
        sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D/2)
        return cos, sin
    else:
        emb = torch.cat(embs, dim=1)  # (WHD, D/2)
        return emb


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[torch.FloatTensor, int],
    theta: float = 10000.0,
    use_real: bool = False,
    theta_rescale_factor: float = 1.0,
    interpolation_factor: float = 1.0,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Precompute the frequency tensor for complex exponential (cis) with given dimensions.
    (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)

    This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
    and the end index 'end'. The 'theta' parameter scales the frequencies.
    The returned tensor contains complex values in complex64 data type.

    Args:
        dim (int): Dimension of the frequency tensor.
        pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (bool, optional): If True, return real part and imaginary part separately.
            Otherwise, return complex numbers.
        theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.

    Returns:
        freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
        freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
    """
    if isinstance(pos, int):
        pos = torch.arange(pos, device=torch.cuda.current_device()).float()

    # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
    # has some connection to NTK literature
    if theta_rescale_factor != 1.0:
        theta *= theta_rescale_factor ** (dim / (dim - 2))

    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, device=torch.cuda.current_device())[: (dim // 2)].float() / dim))  # [D/2]
    # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
    freqs = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]
    if use_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin
    else:
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64     # [S, D/2]
        return freqs_cis
```
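A small sanity check of the 1-D RoPE helper; note that both helpers in this file allocate on `torch.cuda.current_device()`, so a CUDA device is assumed:

```python
import torch

from lightx2v.models.networks.wan.infer.matrix_game2.posemb_layers import get_1d_rotary_pos_embed

# dim=8, positions 0..15; use_real=True returns (cos, sin), each [S, D].
cos, sin = get_1d_rotary_pos_embed(8, 16, theta=256.0, use_real=True)
print(cos.shape, sin.shape)  # torch.Size([16, 8]) torch.Size([16, 8])

# cos^2 + sin^2 == 1 elementwise, as expected for a rotation.
assert torch.allclose(cos**2 + sin**2, torch.ones_like(cos))
```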
lightx2v/models/networks/wan/infer/matrix_game2/pre_infer.py (new file, mode 100644)

```python
import torch

from lightx2v.models.networks.wan.infer.module_io import GridOutput, WanPreInferModuleOutput
from lightx2v.models.networks.wan.infer.self_forcing.pre_infer import WanSFPreInfer, sinusoidal_embedding_1d
from lightx2v.utils.envs import *


def cond_current(conditional_dict, current_start_frame, num_frame_per_block, replace=None, mode="universal"):
    new_cond = {}
    new_cond["cond_concat"] = conditional_dict["image_encoder_output"]["cond_concat"][:, :, current_start_frame : current_start_frame + num_frame_per_block]
    new_cond["visual_context"] = conditional_dict["image_encoder_output"]["visual_context"]
    if replace:
        if current_start_frame == 0:
            last_frame_num = 1 + 4 * (num_frame_per_block - 1)
        else:
            last_frame_num = 4 * num_frame_per_block
        final_frame = 1 + 4 * (current_start_frame + num_frame_per_block - 1)
        if mode != "templerun":
            conditional_dict["text_encoder_output"]["mouse_cond"][:, -last_frame_num + final_frame : final_frame] = replace["mouse"][None, None, :].repeat(1, last_frame_num, 1)
        conditional_dict["text_encoder_output"]["keyboard_cond"][:, -last_frame_num + final_frame : final_frame] = replace["keyboard"][None, None, :].repeat(1, last_frame_num, 1)
    if mode != "templerun":
        new_cond["mouse_cond"] = conditional_dict["text_encoder_output"]["mouse_cond"][:, : 1 + 4 * (current_start_frame + num_frame_per_block - 1)]
    new_cond["keyboard_cond"] = conditional_dict["text_encoder_output"]["keyboard_cond"][:, : 1 + 4 * (current_start_frame + num_frame_per_block - 1)]
    if replace:
        return new_cond, conditional_dict
    else:
        return new_cond


# @amp.autocast(enabled=False)
def rope_params(max_seq_len, dim, theta=10000):
    assert dim % 2 == 0
    freqs = torch.outer(torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)))
    freqs = torch.polar(torch.ones_like(freqs), freqs)
    return freqs


class WanMtxg2PreInfer(WanSFPreInfer):
    def __init__(self, config):
        super().__init__(config)
        d = config["dim"] // config["num_heads"]
        self.freqs = torch.cat(
            [
                rope_params(1024, d - 4 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
                rope_params(1024, 2 * (d // 6)),
            ],
            dim=1,
        ).to(torch.device("cuda"))
        self.dim = config["dim"]

    def img_emb(self, weights, x):
        x = weights.img_emb_0.apply(x)
        x = weights.img_emb_1.apply(x.squeeze(0))
        x = torch.nn.functional.gelu(x, approximate="none")
        x = weights.img_emb_3.apply(x)
        x = weights.img_emb_4.apply(x)
        x = x.unsqueeze(0)
        return x

    @torch.no_grad()
    def infer(self, weights, inputs, kv_start=0, kv_end=0):
        x = self.scheduler.latents_input
        t = self.scheduler.timestep_input
        current_start_frame = self.scheduler.seg_index * self.scheduler.num_frame_per_block
        if self.config["streaming"]:
            current_actions = inputs["current_actions"]
            current_conditional_dict, _ = cond_current(inputs, current_start_frame, self.scheduler.num_frame_per_block, replace=current_actions, mode=self.config["mode"])
        else:
            current_conditional_dict = cond_current(inputs, current_start_frame, self.scheduler.num_frame_per_block, mode=self.config["mode"])
        cond_concat = current_conditional_dict["cond_concat"]
        visual_context = current_conditional_dict["visual_context"]
        x = torch.cat([x.unsqueeze(0), cond_concat], dim=1)

        # embeddings
        x = weights.patch_embedding.apply(x)
        grid_sizes_t, grid_sizes_h, grid_sizes_w = torch.tensor(x.shape[2:], dtype=torch.long)
        grid_sizes = GridOutput(
            tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device),
            tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w),
        )
        x = x.flatten(2).transpose(1, 2)  # B FHW C'
        seq_lens = torch.tensor([u.size(0) for u in x], dtype=torch.long, device=torch.device("cuda"))
        assert seq_lens[0] <= 15 * 1 * 880

        embed_tmp = sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x)  # torch.Size([3, 256])
        embed = self.time_embedding(weights, embed_tmp)  # torch.Size([3, 1536])
        embed0 = self.time_projection(weights, embed).unflatten(dim=0, sizes=t.shape)

        # context
        context_lens = None
        context = self.img_emb(weights, visual_context)

        return WanPreInferModuleOutput(
            embed=embed,
            grid_sizes=grid_sizes,
            x=x.squeeze(0),
            embed0=embed0.squeeze(0),
            seq_lens=seq_lens,
            freqs=self.freqs,
            context=context[0],
            conditional_dict=current_conditional_dict,
        )
```
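The `1 + 4 * (...)` arithmetic in `cond_current` converts latent-frame indices into action-frame indices under the VAE's 4x temporal compression (the first latent frame covers one pixel frame, each later latent frame covers four). A worked sketch of the resulting slice bounds, assuming `num_frame_per_block=3` as in the configs above:

```python
# Latent block k starts at latent frame 3*k; the keyboard/mouse streams are
# indexed in pixel frames, where latent frame f ends at pixel frame 1 + 4*(f - 1).
num_frame_per_block = 3
for seg_index in range(3):
    current_start_frame = seg_index * num_frame_per_block
    end = 1 + 4 * (current_start_frame + num_frame_per_block - 1)
    print(f"block {seg_index}: action conds sliced to [:, :{end}]")
# block 0: [:, :9], block 1: [:, :21], block 2: [:, :33]
```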
lightx2v/models/networks/wan/infer/matrix_game2/transformer_infer.py
0 → 100644
View file @
9826b8ca
import
math
import
torch
from
einops
import
rearrange
try
:
import
flash_attn_interface
FLASH_ATTN_3_AVAILABLE
=
True
except
ImportError
:
from
flash_attn
import
flash_attn_func
FLASH_ATTN_3_AVAILABLE
=
False
from
lightx2v.models.networks.wan.infer.matrix_game2.posemb_layers
import
apply_rotary_emb
,
get_nd_rotary_pos_embed
from
lightx2v.models.networks.wan.infer.self_forcing.transformer_infer
import
WanSFTransformerInfer
,
causal_rope_apply
class
WanMtxg2TransformerInfer
(
WanSFTransformerInfer
):
def
__init__
(
self
,
config
):
super
().
__init__
(
config
)
self
.
_initialize_kv_cache_mouse_and_keyboard
(
self
.
device
,
self
.
dtype
)
self
.
sink_size
=
0
self
.
vae_time_compression_ratio
=
config
[
"action_config"
][
"vae_time_compression_ratio"
]
self
.
windows_size
=
config
[
"action_config"
][
"windows_size"
]
self
.
patch_size
=
config
[
"action_config"
][
"patch_size"
]
self
.
rope_theta
=
config
[
"action_config"
][
"rope_theta"
]
self
.
enable_keyboard
=
config
[
"action_config"
][
"enable_keyboard"
]
self
.
heads_num
=
config
[
"action_config"
][
"heads_num"
]
self
.
hidden_size
=
config
[
"action_config"
][
"hidden_size"
]
self
.
img_hidden_size
=
config
[
"action_config"
][
"img_hidden_size"
]
self
.
keyboard_dim_in
=
config
[
"action_config"
][
"keyboard_dim_in"
]
self
.
keyboard_hidden_dim
=
config
[
"action_config"
][
"keyboard_hidden_dim"
]
self
.
qk_norm
=
config
[
"action_config"
][
"qk_norm"
]
self
.
qkv_bias
=
config
[
"action_config"
][
"qkv_bias"
]
self
.
rope_dim_list
=
config
[
"action_config"
][
"rope_dim_list"
]
self
.
freqs_cos
,
self
.
freqs_sin
=
self
.
get_rotary_pos_embed
(
7500
,
self
.
patch_size
[
1
],
self
.
patch_size
[
2
],
64
,
self
.
rope_dim_list
,
start_offset
=
0
)
self
.
enable_mouse
=
config
[
"action_config"
][
"enable_mouse"
]
if
self
.
enable_mouse
:
self
.
mouse_dim_in
=
config
[
"action_config"
][
"mouse_dim_in"
]
self
.
mouse_hidden_dim
=
config
[
"action_config"
][
"mouse_hidden_dim"
]
self
.
mouse_qk_dim_list
=
config
[
"action_config"
][
"mouse_qk_dim_list"
]
def
get_rotary_pos_embed
(
self
,
video_length
,
height
,
width
,
head_dim
,
rope_dim_list
=
None
,
start_offset
=
0
):
target_ndim
=
3
ndim
=
5
-
2
latents_size
=
[
video_length
+
start_offset
,
height
,
width
]
if
isinstance
(
self
.
patch_size
,
int
):
assert
all
(
s
%
self
.
patch_size
==
0
for
s
in
latents_size
),
f
"Latent size(last
{
ndim
}
dimensions) should be divisible by patch size(
{
self
.
patch_size
}
), but got
{
latents_size
}
."
rope_sizes
=
[
s
//
self
.
patch_size
for
s
in
latents_size
]
elif
isinstance
(
self
.
patch_size
,
list
):
assert
all
(
s
%
self
.
patch_size
[
idx
]
==
0
for
idx
,
s
in
enumerate
(
latents_size
)),
(
f
"Latent size(last
{
ndim
}
dimensions) should be divisible by patch size(
{
self
.
patch_size
}
), but got
{
latents_size
}
."
)
rope_sizes
=
[
s
//
self
.
patch_size
[
idx
]
for
idx
,
s
in
enumerate
(
latents_size
)]
if
len
(
rope_sizes
)
!=
target_ndim
:
rope_sizes
=
[
1
]
*
(
target_ndim
-
len
(
rope_sizes
))
+
rope_sizes
# time axis
if
rope_dim_list
is
None
:
rope_dim_list
=
[
head_dim
//
target_ndim
for
_
in
range
(
target_ndim
)]
assert
sum
(
rope_dim_list
)
==
head_dim
,
"sum(rope_dim_list) should equal to head_dim of attention layer"
freqs_cos
,
freqs_sin
=
get_nd_rotary_pos_embed
(
rope_dim_list
,
rope_sizes
,
theta
=
self
.
rope_theta
,
use_real
=
True
,
theta_rescale_factor
=
1
,
)
return
freqs_cos
[
-
video_length
*
rope_sizes
[
1
]
*
rope_sizes
[
2
]
//
self
.
patch_size
[
0
]
:],
freqs_sin
[
-
video_length
*
rope_sizes
[
1
]
*
rope_sizes
[
2
]
//
self
.
patch_size
[
0
]
:]
def
_initialize_kv_cache
(
self
,
dtype
,
device
):
"""
Initialize a Per-GPU KV cache for the Wan model.
"""
kv_cache1
=
[]
if
self
.
local_attn_size
!=
-
1
:
# Use the local attention size to compute the KV cache size
kv_cache_size
=
self
.
local_attn_size
*
self
.
frame_seq_length
else
:
# Use the default KV cache size
kv_cache_size
=
32760
for
_
in
range
(
self
.
num_transformer_blocks
):
kv_cache1
.
append
(
{
"k"
:
torch
.
zeros
((
kv_cache_size
,
12
,
128
)).
to
(
dtype
).
to
(
device
),
"v"
:
torch
.
zeros
((
kv_cache_size
,
12
,
128
)).
to
(
dtype
).
to
(
device
),
"global_end_index"
:
0
,
"local_end_index"
:
0
,
}
)
self
.
kv_cache1_default
=
kv_cache1
    def _initialize_kv_cache_mouse_and_keyboard(self, device, dtype):
        """
        Initialize per-GPU KV caches for the mouse and keyboard action branches.
        """
        kv_cache_mouse = []
        kv_cache_keyboard = []
        if self.local_attn_size != -1:
            kv_cache_size = self.local_attn_size
        else:
            kv_cache_size = 15 * 1
        for _ in range(self.num_transformer_blocks):
            kv_cache_keyboard.append(
                {
                    "k": torch.zeros([1, kv_cache_size, 16, 64], dtype=dtype, device=device),
                    "v": torch.zeros([1, kv_cache_size, 16, 64], dtype=dtype, device=device),
                    "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                    "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
                }
            )
            kv_cache_mouse.append(
                {
                    "k": torch.zeros([self.frame_seq_length, kv_cache_size, 16, 64], dtype=dtype, device=device),
                    "v": torch.zeros([self.frame_seq_length, kv_cache_size, 16, 64], dtype=dtype, device=device),
                    "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
                    "local_end_index": torch.tensor([0], dtype=torch.long, device=device),
                }
            )
        self.kv_cache_keyboard = kv_cache_keyboard
        self.kv_cache_mouse = kv_cache_mouse
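Note the asymmetry in the shapes above: the keyboard cache stores one K/V token per latent frame (leading dim 1), while the mouse cache keeps a token per spatial position (leading dim frame_seq_length), so it is larger by a factor of the per-frame sequence length. A quick shape check under assumed sizes:

import torch

# Shape sketch only; cache_frames follows local_attn_size (assumed 6 here).
frame_seq_length, cache_frames = 880, 6
kb_k = torch.zeros([1, cache_frames, 16, 64])                 # one token per frame
ms_k = torch.zeros([frame_seq_length, cache_frames, 16, 64])  # one token per spatial position
print(kb_k.numel(), ms_k.numel())  # 6144 vs 5406720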
    def infer_self_attn_with_kvcache(self, phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
        if hasattr(phase, "smooth_norm1_weight"):
            norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor
            norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor
        else:
            norm1_weight = 1 + scale_msa.squeeze()
            norm1_bias = shift_msa.squeeze()
        norm1_out = phase.norm1.apply(x)
        if self.sensitive_layer_dtype != self.infer_dtype:
            norm1_out = norm1_out.to(self.sensitive_layer_dtype)
        norm1_out.mul_(norm1_weight[0:1, :]).add_(norm1_bias[0:1, :])
        if self.sensitive_layer_dtype != self.infer_dtype:
            norm1_out = norm1_out.to(self.infer_dtype)
        s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
        q0 = phase.self_attn_q.apply(norm1_out)
        k0 = phase.self_attn_k.apply(norm1_out)
        q = phase.self_attn_norm_q.apply(q0).view(s, n, d)
        k = phase.self_attn_norm_k.apply(k0).view(s, n, d)
        v = phase.self_attn_v.apply(norm1_out).view(s, n, d)
        seg_index = self.scheduler.seg_index
        frame_seqlen = math.prod(grid_sizes[0][1:]).item()
        current_start = seg_index * self.num_frame_per_block * self.frame_seq_length
        current_start_frame = current_start // frame_seqlen
        q = causal_rope_apply(q.unsqueeze(0), grid_sizes, freqs, start_frame=current_start_frame).type_as(v)[0]
        k = causal_rope_apply(k.unsqueeze(0), grid_sizes, freqs, start_frame=current_start_frame).type_as(v)[0]
        current_end = current_start + q.shape[0]
        sink_tokens = self.sink_size * frame_seqlen
        kv_cache_size = self.kv_cache1[self.block_idx]["k"].shape[0]
        num_new_tokens = q.shape[0]
        if (current_end > self.kv_cache1[self.block_idx]["global_end_index"]) and (num_new_tokens + self.kv_cache1[self.block_idx]["local_end_index"] > kv_cache_size):
            num_evicted_tokens = num_new_tokens + self.kv_cache1[self.block_idx]["local_end_index"] - kv_cache_size
            num_rolled_tokens = self.kv_cache1[self.block_idx]["local_end_index"] - num_evicted_tokens - sink_tokens
            self.kv_cache1[self.block_idx]["k"][sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache1[self.block_idx]["k"][
                sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens
            ].clone()
            self.kv_cache1[self.block_idx]["v"][sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache1[self.block_idx]["v"][
                sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens
            ].clone()
            # Insert the new keys/values at the end
            local_end_index = self.kv_cache1[self.block_idx]["local_end_index"] + current_end - self.kv_cache1[self.block_idx]["global_end_index"] - num_evicted_tokens
            local_start_index = local_end_index - num_new_tokens
            self.kv_cache1[self.block_idx]["k"][local_start_index:local_end_index] = k
            self.kv_cache1[self.block_idx]["v"][local_start_index:local_end_index] = v
        else:
            # Assign new keys/values directly up to current_end
            local_end_index = self.kv_cache1[self.block_idx]["local_end_index"] + current_end - self.kv_cache1[self.block_idx]["global_end_index"]
            local_start_index = local_end_index - num_new_tokens
            self.kv_cache1[self.block_idx]["k"][local_start_index:local_end_index] = k
            self.kv_cache1[self.block_idx]["v"][local_start_index:local_end_index] = v
        attn_k = self.kv_cache1[self.block_idx]["k"][max(0, local_end_index - self.max_attention_size) : local_end_index]
        attn_v = self.kv_cache1[self.block_idx]["v"][max(0, local_end_index - self.max_attention_size) : local_end_index]
        self.kv_cache1[self.block_idx]["local_end_index"] = local_end_index
        self.kv_cache1[self.block_idx]["global_end_index"] = current_end
        k_lens = torch.empty_like(seq_lens).fill_(attn_k.size(0))
        cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens)
        if self.clean_cuda_cache:
            del norm1_out, norm1_weight, norm1_bias
            torch.cuda.empty_cache()
        if self.config["seq_parallel"]:
            attn_out = phase.self_attn_1_parallel.apply(
                q=q,
                k=attn_k,
                v=attn_v,
                img_qkv_len=q.shape[0],
                cu_seqlens_qkv=cu_seqlens_q,
                attention_module=phase.self_attn_1,
                seq_p_group=self.seq_p_group,
                model_cls=self.config["model_cls"],
            )
        else:
            attn_out = phase.self_attn_1.apply(
                q=q,
                k=attn_k,
                v=attn_v,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_kv=cu_seqlens_k,
                max_seqlen_q=q.size(0),
                max_seqlen_kv=attn_k.size(0),
                model_cls=self.config["model_cls"],
            )
        y = phase.self_attn_o.apply(attn_out)
        if self.clean_cuda_cache:
            del q, k, v, attn_out
            torch.cuda.empty_cache()
        return y
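The eviction branch above is the heart of the streaming mode: an attention-sink-aware sliding window over the KV cache. When appending a new segment would overflow the buffer, everything past the sink region is rolled left by the number of evicted tokens and the new keys/values are written at the tail. A minimal 1-D re-implementation of just that index arithmetic (toy sizes, no heads, names are mine, not the repo's):

import torch

def append_with_eviction(cache, state, new, sink_tokens=0):
    """Append `new` (T, D) to `cache` (C, D), rolling out old tokens on overflow."""
    cache_size, t = cache.shape[0], new.shape[0]
    end = state["global_end"] + t  # global position after this append
    if end > state["global_end"] and t + state["local_end"] > cache_size:
        evicted = t + state["local_end"] - cache_size
        rolled = state["local_end"] - evicted - sink_tokens
        # Shift the surviving window left, preserving the sink prefix.
        cache[sink_tokens : sink_tokens + rolled] = cache[
            sink_tokens + evicted : sink_tokens + evicted + rolled
        ].clone()
        local_end = state["local_end"] + end - state["global_end"] - evicted
    else:
        local_end = state["local_end"] + end - state["global_end"]
    cache[local_end - t : local_end] = new
    state["local_end"], state["global_end"] = local_end, end
    return cache[:local_end]

cache = torch.zeros(8, 1)
state = {"global_end": 0, "local_end": 0}
for step in range(5):
    window = append_with_eviction(cache, state, torch.full((3, 1), float(step)))
    print(step, window.squeeze(-1).tolist())
# Final window holds the last 8 tokens: [2, 2, 3, 3, 3, 4, 4, 4]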
    def infer_cross_attn_with_kvcache(self, phase, x, context, y_out, gate_msa):
        num_frames = gate_msa.shape[0]
        frame_seqlen = x.shape[0] // gate_msa.shape[0]
        x.add_((y_out.unflatten(dim=0, sizes=(num_frames, frame_seqlen)) * gate_msa).flatten(0, 1))
        norm3_out = phase.norm3.apply(x)
        n, d = self.num_heads, self.head_dim
        q = phase.cross_attn_q.apply(norm3_out)
        q = phase.cross_attn_norm_q.apply(q).view(-1, n, d)
        if not self.crossattn_cache[self.block_idx]["is_init"]:
            self.crossattn_cache[self.block_idx]["is_init"] = True
            k = phase.cross_attn_k.apply(context)
            k = phase.cross_attn_norm_k.apply(k).view(-1, n, d)
            v = phase.cross_attn_v.apply(context)
            v = v.view(-1, n, d)
            self.crossattn_cache[self.block_idx]["k"] = k
            self.crossattn_cache[self.block_idx]["v"] = v
        else:
            k = self.crossattn_cache[self.block_idx]["k"]
            v = self.crossattn_cache[self.block_idx]["v"]
        cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(
            q,
            k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device),
        )
        attn_out = phase.cross_attn_1.apply(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_kv=cu_seqlens_k,
            max_seqlen_q=q.size(0),
            max_seqlen_kv=k.size(0),
            model_cls=self.config["model_cls"],
        )
        attn_out = phase.cross_attn_o.apply(attn_out)
        if self.clean_cuda_cache:
            del q, k, v, norm3_out, context
            torch.cuda.empty_cache()
        return x, attn_out
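Because the conditioning context is fixed across denoising steps and autoregressive segments, cross-attention K/V are projected once per block and replayed from the cache; only the is_init flag gates the first projection. A minimal sketch of that pattern (names and sizes are mine):

import torch

# Hypothetical stand-in for the one-shot cross-attention K/V cache above.
class CrossKV:
    def __init__(self):
        self.is_init = False
        self.k = self.v = None

    def get(self, context, k_proj, v_proj):
        if not self.is_init:  # project once, reuse for every later segment
            self.k, self.v = k_proj(context), v_proj(context)
            self.is_init = True
        return self.k, self.v

ctx = torch.randn(512, 1024)            # e.g. text_len x model dim (assumed)
k_proj = v_proj = torch.nn.Linear(1024, 1024)
cache = CrossKV()
k1, _ = cache.get(ctx, k_proj, v_proj)
k2, _ = cache.get(ctx, k_proj, v_proj)
assert k1 is k2                         # second call hits the cache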
    def infer_action_model(self, phase, x, grid_sizes, seq_lens, mouse_condition=None, keyboard_condition=None, is_causal=False, use_rope_keyboard=True):
        tt, th, tw = grid_sizes
        current_start = self.scheduler.seg_index * self.num_frame_per_block
        start_frame = current_start
        B, N_frames, C = keyboard_condition.shape
        assert tt * th * tw == x.shape[0]
        assert ((N_frames - 1) + self.vae_time_compression_ratio) % self.vae_time_compression_ratio == 0
        N_feats = int((N_frames - 1) / self.vae_time_compression_ratio) + 1
        # Define freqs_cis early so it's available for both mouse and keyboard
        freqs_cis = (self.freqs_cos, self.freqs_sin)
        cond1 = N_feats == tt
        cond2 = is_causal and not self.kv_cache_mouse
        cond3 = (N_frames - 1) // self.vae_time_compression_ratio + 1 == current_start + self.num_frame_per_block
        assert (cond1 and (cond2 or not is_causal)) or (cond3 and is_causal)
        x = x.unsqueeze(0)
        if self.enable_mouse and mouse_condition is not None:
            hidden_states = rearrange(x, "B (T S) C -> (B S) T C", T=tt, S=th * tw)
            # 65*272*480 -> 17*(272//16)*(480//16) -> 8670
            B, N_frames, C = mouse_condition.shape
        else:
            hidden_states = x
        pad_t = self.vae_time_compression_ratio * self.windows_size
        if self.enable_mouse and mouse_condition is not None:
            pad = mouse_condition[:, 0:1, :].expand(-1, pad_t, -1)
            mouse_condition = torch.cat([pad, mouse_condition], dim=1)
            if is_causal and self.kv_cache_mouse is not None:
                mouse_condition = mouse_condition[:, self.vae_time_compression_ratio * (N_feats - self.num_frame_per_block - self.windows_size) + pad_t :, :]
                group_mouse = [
                    mouse_condition[:, self.vae_time_compression_ratio * (i - self.windows_size) + pad_t : i * self.vae_time_compression_ratio + pad_t, :]
                    for i in range(self.num_frame_per_block)
                ]
            else:
                group_mouse = [
                    mouse_condition[:, self.vae_time_compression_ratio * (i - self.windows_size) + pad_t : i * self.vae_time_compression_ratio + pad_t, :]
                    for i in range(N_feats)
                ]
            group_mouse = torch.stack(group_mouse, dim=1)
            S = th * tw
            group_mouse = group_mouse.unsqueeze(-1).expand(B, self.num_frame_per_block, pad_t, C, S)
            group_mouse = group_mouse.permute(0, 4, 1, 2, 3).reshape(B * S, self.num_frame_per_block, pad_t * C)
            group_mouse = torch.cat([hidden_states, group_mouse], dim=-1)
            # mouse_mlp
            # Note: the batch dimension here is unavoidable, so use torch.nn.functional directly
            group_mouse = torch.nn.functional.linear(group_mouse, phase.mouse_mlp_0.weight.T, phase.mouse_mlp_0.bias)
            group_mouse = torch.nn.functional.gelu(group_mouse, approximate="tanh")
            group_mouse = torch.nn.functional.linear(group_mouse, phase.mouse_mlp_2.weight.T, phase.mouse_mlp_2.bias)
            group_mouse = torch.nn.functional.layer_norm(group_mouse, (group_mouse.shape[-1],), phase.mouse_mlp_3.weight.T, phase.mouse_mlp_3.bias, 1e-5)
            # qkv
            mouse_qkv = torch.nn.functional.linear(group_mouse, phase.t_qkv.weight.T)
            q0, k0, v = rearrange(mouse_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)  # BHW F H C, torch.Size([880, 3, 16, 64])
            # qk norm (RMS)
            q = q0 * torch.rsqrt(q0.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
            k = k0 * torch.rsqrt(k0.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
            q, k = apply_rotary_emb(q, k, freqs_cis, start_offset=start_frame, head_first=False)
            # TODO: add cache here
            if is_causal:
                current_end = current_start + q.shape[1]
                assert q.shape[1] == self.num_frame_per_block
                sink_size = 0
                max_attention_size = self.local_attn_size
                sink_tokens = sink_size * 1
                kv_cache_size = self.kv_cache_mouse[self.block_idx]["k"].shape[1]
                num_new_tokens = q.shape[1]
                if (current_end > self.kv_cache_mouse[self.block_idx]["global_end_index"].item()) and (
                    num_new_tokens + self.kv_cache_mouse[self.block_idx]["local_end_index"].item() > kv_cache_size
                ):
                    num_evicted_tokens = num_new_tokens + self.kv_cache_mouse[self.block_idx]["local_end_index"].item() - kv_cache_size
                    num_rolled_tokens = self.kv_cache_mouse[self.block_idx]["local_end_index"].item() - num_evicted_tokens - sink_tokens
                    self.kv_cache_mouse[self.block_idx]["k"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_mouse[self.block_idx]["k"][
                        :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens
                    ].clone()
                    self.kv_cache_mouse[self.block_idx]["v"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_mouse[self.block_idx]["v"][
                        :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens
                    ].clone()
                    # Insert the new keys/values at the end
                    local_end_index = (
                        self.kv_cache_mouse[self.block_idx]["local_end_index"].item()
                        + current_end
                        - self.kv_cache_mouse[self.block_idx]["global_end_index"].item()
                        - num_evicted_tokens
                    )
                    local_start_index = local_end_index - num_new_tokens
                else:
                    local_end_index = self.kv_cache_mouse[self.block_idx]["local_end_index"].item() + current_end - self.kv_cache_mouse[self.block_idx]["global_end_index"].item()
                    local_start_index = local_end_index - num_new_tokens
                self.kv_cache_mouse[self.block_idx]["k"][:, local_start_index:local_end_index] = k
                self.kv_cache_mouse[self.block_idx]["v"][:, local_start_index:local_end_index] = v
                attn_k = self.kv_cache_mouse[self.block_idx]["k"][:, max(0, local_end_index - max_attention_size) : local_end_index]
                attn_v = self.kv_cache_mouse[self.block_idx]["v"][:, max(0, local_end_index - max_attention_size) : local_end_index]
                attn = flash_attn_interface.flash_attn_func(q, attn_k, attn_v)
                self.kv_cache_mouse[self.block_idx]["global_end_index"].fill_(current_end)
                self.kv_cache_mouse[self.block_idx]["local_end_index"].fill_(local_end_index)
            else:
                attn = flash_attn_func(q, k, v)
            attn = rearrange(attn, "(b S) T h d -> b (T S) (h d)", b=B)
            hidden_states = rearrange(x, "(B S) T C -> B (T S) C", B=B)
            attn = phase.proj_mouse.apply(attn[0]).unsqueeze(0)
            hidden_states = hidden_states + attn
        if self.enable_keyboard and keyboard_condition is not None:
            pad = keyboard_condition[:, 0:1, :].expand(-1, pad_t, -1)
            keyboard_condition = torch.cat([pad, keyboard_condition], dim=1)
            if is_causal and self.kv_cache_keyboard is not None:
                keyboard_condition = keyboard_condition[:, self.vae_time_compression_ratio * (N_feats - self.num_frame_per_block - self.windows_size) + pad_t :, :]
                # keyboard_condition[:, self.vae_time_compression_ratio*(start_frame - self.windows_size) + pad_t:start_frame * self.vae_time_compression_ratio + pad_t,:]
                keyboard_condition = phase.keyboard_embed_0.apply(keyboard_condition[0])
                keyboard_condition = torch.nn.functional.silu(keyboard_condition)
                keyboard_condition = phase.keyboard_embed_2.apply(keyboard_condition).unsqueeze(0)
                group_keyboard = [
                    keyboard_condition[:, self.vae_time_compression_ratio * (i - self.windows_size) + pad_t : i * self.vae_time_compression_ratio + pad_t, :]
                    for i in range(self.num_frame_per_block)
                ]
            else:
                keyboard_condition = phase.keyboard_embed_0.apply(keyboard_condition[0])
                keyboard_condition = torch.nn.functional.silu(keyboard_condition)
                keyboard_condition = phase.keyboard_embed_2.apply(keyboard_condition).unsqueeze(0)
                group_keyboard = [
                    keyboard_condition[:, self.vae_time_compression_ratio * (i - self.windows_size) + pad_t : i * self.vae_time_compression_ratio + pad_t, :]
                    for i in range(N_feats)
                ]
            group_keyboard = torch.stack(group_keyboard, dim=1)  # B F RW C
            group_keyboard = group_keyboard.reshape(shape=(group_keyboard.shape[0], group_keyboard.shape[1], -1))
            # apply cross attn
            mouse_q = phase.mouse_attn_q.apply(hidden_states[0]).unsqueeze(0)
            keyboard_kv = phase.keyboard_attn_kv.apply(group_keyboard[0]).unsqueeze(0)
            B, L, HD = mouse_q.shape
            D = HD // self.heads_num
            q = mouse_q.view(B, L, self.heads_num, D)
            B, L, KHD = keyboard_kv.shape
            k, v = keyboard_kv.view(B, L, 2, self.heads_num, D).permute(2, 0, 1, 3, 4)
            # qk norm (RMS)
            q = q * torch.rsqrt(q.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
            k = k * torch.rsqrt(k.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
            S = th * tw
            assert S == 880
            # position embed
            if use_rope_keyboard:
                B, TS, H, D = q.shape
                T_ = TS // S
                q = q.view(B, T_, S, H, D).transpose(1, 2).reshape(B * S, T_, H, D)
                q, k = apply_rotary_emb(q, k, freqs_cis, start_offset=start_frame, head_first=False)
                k1, k2, k3, k4 = k.shape
                k = k.expand(S, k2, k3, k4)
                v = v.expand(S, k2, k3, k4)
                if is_causal:
                    current_end = current_start + k.shape[1]
                    assert k.shape[1] == self.num_frame_per_block
                    sink_size = 0
                    max_attention_size = self.local_attn_size
                    sink_tokens = sink_size * 1
                    kv_cache_size = self.kv_cache_keyboard[self.block_idx]["k"].shape[1]
                    num_new_tokens = k.shape[1]
                    if (current_end > self.kv_cache_keyboard[self.block_idx]["global_end_index"].item()) and (
                        num_new_tokens + self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() > kv_cache_size
                    ):
                        num_evicted_tokens = num_new_tokens + self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() - kv_cache_size
                        num_rolled_tokens = self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() - num_evicted_tokens - sink_tokens
                        self.kv_cache_keyboard[self.block_idx]["k"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_keyboard[self.block_idx]["k"][
                            :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens
                        ].clone()
                        self.kv_cache_keyboard[self.block_idx]["v"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_keyboard[self.block_idx]["v"][
                            :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens
                        ].clone()
                        # Insert the new keys/values at the end
                        local_end_index = (
                            self.kv_cache_keyboard[self.block_idx]["local_end_index"].item()
                            + current_end
                            - self.kv_cache_keyboard[self.block_idx]["global_end_index"].item()
                            - num_evicted_tokens
                        )
                        local_start_index = local_end_index - num_new_tokens
                    else:
                        local_end_index = (
                            self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() + current_end - self.kv_cache_keyboard[self.block_idx]["global_end_index"].item()
                        )
                        local_start_index = local_end_index - num_new_tokens
                    assert k.shape[0] == 880
                    # BS == 1, otherwise the cache should not be saved and the load method must be modified
                    self.kv_cache_keyboard[self.block_idx]["k"][:, local_start_index:local_end_index] = k[:1]
                    self.kv_cache_keyboard[self.block_idx]["v"][:, local_start_index:local_end_index] = v[:1]
                    if FLASH_ATTN_3_AVAILABLE:
                        attn_k = self.kv_cache_keyboard[self.block_idx]["k"][:, max(0, local_end_index - max_attention_size) : local_end_index].repeat(S, 1, 1, 1)
                        attn_v = self.kv_cache_keyboard[self.block_idx]["v"][:, max(0, local_end_index - max_attention_size) : local_end_index].repeat(S, 1, 1, 1)
                        attn = flash_attn_interface.flash_attn_func(q, attn_k, attn_v)
                    else:
                        attn = flash_attn_func(
                            q,
                            self.kv_cache_keyboard[self.block_idx]["k"][max(0, local_end_index - max_attention_size) : local_end_index].repeat(S, 1, 1, 1),
                            self.kv_cache_keyboard[self.block_idx]["v"][max(0, local_end_index - max_attention_size) : local_end_index].repeat(S, 1, 1, 1),
                        )
                    self.kv_cache_keyboard[self.block_idx]["global_end_index"].fill_(current_end)
                    self.kv_cache_keyboard[self.block_idx]["local_end_index"].fill_(local_end_index)
                else:
                    attn = flash_attn_func(q, k, v, causal=False)
                attn = rearrange(attn, "(B S) T H D -> B (T S) (H D)", S=S)
            else:
                if is_causal:
                    current_start = start_frame
                    current_end = current_start + k.shape[1]
                    assert k.shape[1] == self.num_frame_per_block
                    sink_size = 0
                    local_attn_size = self.local_attn_size
                    max_attention_size = self.local_attn_size
                    sink_tokens = sink_size * 1
                    kv_cache_size = self.kv_cache_keyboard[self.block_idx]["k"].shape[1]
                    num_new_tokens = k.shape[1]
                    if (current_end > self.kv_cache_keyboard[self.block_idx]["global_end_index"].item()) and (
                        num_new_tokens + self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() > kv_cache_size
                    ):
                        num_evicted_tokens = num_new_tokens + self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() - kv_cache_size
                        num_rolled_tokens = self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() - num_evicted_tokens - sink_tokens
                        self.kv_cache_keyboard[self.block_idx]["k"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_keyboard[self.block_idx]["k"][
                            :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens
                        ].clone()
                        self.kv_cache_keyboard[self.block_idx]["v"][:, sink_tokens : sink_tokens + num_rolled_tokens] = self.kv_cache_keyboard[self.block_idx]["v"][
                            :, sink_tokens + num_evicted_tokens : sink_tokens + num_evicted_tokens + num_rolled_tokens
                        ].clone()
                        # Insert the new keys/values at the end
                        local_end_index = (
                            self.kv_cache_keyboard[self.block_idx]["local_end_index"].item()
                            + current_end
                            - self.kv_cache_keyboard[self.block_idx]["global_end_index"].item()
                            - num_evicted_tokens
                        )
                        local_start_index = local_end_index - num_new_tokens
                    else:
                        local_end_index = (
                            self.kv_cache_keyboard[self.block_idx]["local_end_index"].item() + current_end - self.kv_cache_keyboard[self.block_idx]["global_end_index"].item()
                        )
                        local_start_index = local_end_index - num_new_tokens
                    self.kv_cache_keyboard[self.block_idx]["k"][:, local_start_index:local_end_index] = k
                    self.kv_cache_keyboard[self.block_idx]["v"][:, local_start_index:local_end_index] = v
                    attn = flash_attn_func(
                        q,
                        self.kv_cache_keyboard[self.block_idx]["k"][:, max(0, local_end_index - max_attention_size) : local_end_index],
                        self.kv_cache_keyboard[self.block_idx]["v"][:, max(0, local_end_index - max_attention_size) : local_end_index],
                    )
                    self.kv_cache_keyboard[self.block_idx]["global_end_index"].fill_(current_end)
                    self.kv_cache_keyboard[self.block_idx]["local_end_index"].fill_(local_end_index)
                else:
                    attn = flash_attn_func(q, k, v)
                attn = rearrange(attn, "B L H D -> B L (H D)")
            attn = phase.proj_keyboard.apply(attn[0]).unsqueeze(0)
            hidden_states = hidden_states + attn
        hidden_states = hidden_states.squeeze(0)
        return hidden_states
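The windowing above maps each latent feature index i to a slice of raw action frames: after left-padding with pad_t = vae_time_compression_ratio * windows_size repeats of the first frame, feature i reads raw frames [r*(i - w) + pad_t, r*i + pad_t). A small sketch of just that slicing, assuming r=4 and w=1:

import torch

r, w = 4, 1          # assumed vae_time_compression_ratio, windows_size
pad_t = r * w
n_frames = 13        # raw action frames; (n_frames - 1) % r == 0
n_feats = (n_frames - 1) // r + 1

actions = torch.arange(n_frames).view(1, n_frames, 1).float()
pad = actions[:, 0:1, :].expand(-1, pad_t, -1)   # repeat the first frame
actions = torch.cat([pad, actions], dim=1)

# Each latent feature i gathers a pad_t-long window ending at raw frame r*i.
groups = [actions[:, r * (i - w) + pad_t : i * r + pad_t, :] for i in range(n_feats)]
print([g.squeeze().tolist() for g in groups])
# [[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 2.0, 3.0], [4.0, 5.0, 6.0, 7.0], [8.0, 9.0, 10.0, 11.0]]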
    def infer_ffn(self, phase, x, c_shift_msa, c_scale_msa):
        num_frames = c_shift_msa.shape[0]
        frame_seqlen = x.shape[0] // c_shift_msa.shape[0]
        x = phase.norm2.apply(x).unsqueeze(0)
        x = x.unflatten(dim=1, sizes=(num_frames, frame_seqlen))
        c_scale_msa = c_scale_msa.unsqueeze(0)
        c_shift_msa = c_shift_msa.unsqueeze(0)
        x = x * (1 + c_scale_msa) + c_shift_msa
        x = x.flatten(1, 2).squeeze(0)
        y = phase.ffn_0.apply(x)
        y = torch.nn.functional.gelu(y, approximate="tanh")
        y = phase.ffn_2.apply(y)
        return y
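infer_ffn applies the AdaLN-style modulation per latent frame: the flat (T*S, C) token sequence is unflattened to (T, S, C) so every frame receives its own shift/scale before the GELU MLP, then flattened back. A compact sketch of that round trip (all shapes assumed):

import torch

T, S, C = 3, 880, 1536            # frames, tokens per frame, channels (assumed)
x = torch.randn(T * S, C)
c_shift = torch.randn(T, 1, C)    # one modulation vector per latent frame
c_scale = torch.randn(T, 1, C)

y = x.unsqueeze(0).unflatten(dim=1, sizes=(T, S))          # (1, T, S, C)
y = y * (1 + c_scale.unsqueeze(0)) + c_shift.unsqueeze(0)  # broadcast over S
y = y.flatten(1, 2).squeeze(0)                             # back to (T*S, C)

# All tokens of frame 0 received the same affine transform:
expected = x[:S] * (1 + c_scale[0]) + c_shift[0]
assert torch.allclose(y[:S], expected)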
    def post_process(self, x, y, c_gate_msa, pre_infer_out=None):
        x = x + y * c_gate_msa[0]
        x = x.squeeze(0)
        return x
    def infer_block_witch_kvcache(self, block, x, pre_infer_out):
        if hasattr(block.compute_phases[0], "before_proj"):
            x = block.compute_phases[0].before_proj.apply(x) + pre_infer_out.x
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.pre_process(
            block.compute_phases[0].modulation,
            pre_infer_out.embed0,
        )
        y_out = self.infer_self_attn_with_kvcache(
            block.compute_phases[0],
            pre_infer_out.grid_sizes.tensor,
            x,
            pre_infer_out.seq_lens,
            pre_infer_out.freqs,
            shift_msa,
            scale_msa,
        )
        x, attn_out = self.infer_cross_attn_with_kvcache(
            block.compute_phases[1],
            x,
            pre_infer_out.context,
            y_out,
            gate_msa,
        )
        x = x + attn_out
        if len(block.compute_phases) == 4:
            if self.config["mode"] != "templerun":
                x = self.infer_action_model(
                    phase=block.compute_phases[2],
                    x=x,
                    grid_sizes=pre_infer_out.grid_sizes.tensor[0],
                    seq_lens=pre_infer_out.seq_lens,
                    mouse_condition=pre_infer_out.conditional_dict["mouse_cond"],
                    keyboard_condition=pre_infer_out.conditional_dict["keyboard_cond"],
                    is_causal=True,
                    use_rope_keyboard=True,
                )
            else:
                x = self.infer_action_model(
                    phase=block.compute_phases[2],
                    x=x,
                    grid_sizes=pre_infer_out.grid_sizes.tensor[0],
                    seq_lens=pre_infer_out.seq_lens,
                    keyboard_condition=pre_infer_out.conditional_dict["keyboard_cond"],
                    is_causal=True,
                    use_rope_keyboard=True,
                )
            y = self.infer_ffn(block.compute_phases[3], x, c_shift_msa, c_scale_msa)
        elif len(block.compute_phases) == 3:
            y = self.infer_ffn(block.compute_phases[2], x, c_shift_msa, c_scale_msa)
        x = self.post_process(x, y, c_gate_msa, pre_infer_out)
        return x
    def infer_non_blocks(self, weights, x, e):
        num_frames = e.shape[0]
        frame_seqlen = x.shape[0] // e.shape[0]
        e = e.unsqueeze(0).unsqueeze(2)
        x = weights.norm.apply(x).unsqueeze(0)
        x = x.unflatten(dim=1, sizes=(num_frames, frame_seqlen))
        modulation = weights.head_modulation.tensor
        e = (modulation.unsqueeze(1) + e).chunk(2, dim=2)
        x = x * (1 + e[1]) + e[0]
        x = torch.nn.functional.linear(x, weights.head.weight.T, weights.head.bias)
        if self.clean_cuda_cache:
            del e
            torch.cuda.empty_cache()
        return x
lightx2v/models/networks/wan/infer/module_io.py
...
@@ -20,3 +20,4 @@ class WanPreInferModuleOutput:
     freqs: torch.Tensor
     context: torch.Tensor
     adapter_args: Dict[str, Any] = field(default_factory=dict)
+    conditional_dict: Dict[str, Any] = field(default_factory=dict)
lightx2v/models/networks/wan/matrix_game2_model.py
0 → 100644
import json
import os

import torch
from safetensors import safe_open

from lightx2v.models.networks.wan.infer.matrix_game2.pre_infer import WanMtxg2PreInfer
from lightx2v.models.networks.wan.infer.matrix_game2.transformer_infer import WanMtxg2TransformerInfer
from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
from lightx2v.models.networks.wan.sf_model import WanSFModel
from lightx2v.models.networks.wan.weights.matrix_game2.pre_weights import WanMtxg2PreWeights
from lightx2v.models.networks.wan.weights.matrix_game2.transformer_weights import WanActionTransformerWeights
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *


class WanSFMtxg2Model(WanSFModel):
    pre_weight_class = WanMtxg2PreWeights
    transformer_weight_class = WanActionTransformerWeights

    def __init__(self, model_path, config, device):
        super().__init__(model_path, config, device)

    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        with safe_open(file_path, framework="pt", device=str(self.device)) as f:
            return {
                key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE()))
                for key in f.keys()
            }

    def _load_ckpt(self, unified_dtype, sensitive_layer):
        file_path = os.path.join(self.config["model_path"], f"{self.config['sub_model_folder']}/{self.config['sub_model_name']}")
        _weight_dict = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
        weight_dict = {}
        for k, v in _weight_dict.items():
            name = k[6:]
            weight = v.to(torch.bfloat16).to(self.device)
            weight_dict.update({name: weight})
        del _weight_dict
        return weight_dict

    def _init_infer_class(self):
        # update config by real model config
        with open(os.path.join(self.config["model_path"], self.config["sub_model_folder"], "config.json")) as f:
            model_config = json.load(f)
        for k in model_config.keys():
            self.config[k] = model_config[k]
        self.pre_infer_class = WanMtxg2PreInfer
        self.post_infer_class = WanPostInfer
        self.transformer_infer_class = WanMtxg2TransformerInfer
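_load_ckpt strips the first six characters of every checkpoint key (k[6:]), i.e. it assumes keys arrive prefixed with "model.". A tiny sketch of the remapping (the prefix is my reading of the slice, not stated in the file):

import torch

# Assumed raw checkpoint layout: every key prefixed with "model."
_weight_dict = {
    "model.patch_embedding.weight": torch.zeros(4),
    "model.blocks.0.self_attn.q.weight": torch.zeros(4),
}
weight_dict = {k[6:]: v.to(torch.bfloat16) for k, v in _weight_dict.items()}
print(sorted(weight_dict))
# ['blocks.0.self_attn.q.weight', 'patch_embedding.weight']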
lightx2v/models/networks/wan/sf_model.py
...
@@ -11,7 +11,8 @@ from lightx2v.models.networks.wan.model import WanModel
 class WanSFModel(WanModel):
     def __init__(self, model_path, config, device):
         super().__init__(model_path, config, device)
-        self.to_cuda()
+        if config["model_cls"] not in ["wan2.1_sf_mtxg2"]:
+            self.to_cuda()

     def _load_ckpt(self, unified_dtype, sensitive_layer):
         sf_confg = self.config["sf_config"]
...
lightx2v/models/networks/wan/weights/matrix_game2/pre_weights.py
0 → 100644
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import (
    CONV3D_WEIGHT_REGISTER,
    LN_WEIGHT_REGISTER,
    MM_WEIGHT_REGISTER,
)


class WanMtxg2PreWeights(WeightModule):
    def __init__(self, config):
        super().__init__()
        self.in_dim = config["in_dim"]
        self.dim = config["dim"]
        self.patch_size = (1, 2, 2)
        self.config = config
        # patch
        self.add_module(
            "patch_embedding",
            CONV3D_WEIGHT_REGISTER["Default"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size),
        )
        # time
        self.add_module(
            "time_embedding_0",
            MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias"),
        )
        self.add_module(
            "time_embedding_2",
            MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias"),
        )
        self.add_module(
            "time_projection_1",
            MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"),
        )
        # img_emb
        self.add_module(
            "img_emb_0",
            LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias", eps=1e-5),
        )
        self.add_module(
            "img_emb_1",
            MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias"),
        )
        self.add_module(
            "img_emb_3",
            MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias"),
        )
        self.add_module(
            "img_emb_4",
            LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias", eps=1e-5),
        )
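For orientation, the weight modules registered above hold checkpoint key names only; the tensors are bound later from the dict produced by _load_ckpt. The sketch below is a generic stand-in for that indirection, not the repo's actual MM_WEIGHT_REGISTER implementation:

import torch

# Hypothetical wrapper: remember key names now, bind real tensors later.
class MMWeight:
    def __init__(self, w_key, b_key):
        self.w_key, self.b_key = w_key, b_key

    def load(self, weight_dict):
        self.weight = weight_dict[self.w_key]
        self.bias = weight_dict[self.b_key]

    def apply(self, x):
        return torch.nn.functional.linear(x, self.weight, self.bias)

mm = MMWeight("time_embedding.0.weight", "time_embedding.0.bias")
mm.load({"time_embedding.0.weight": torch.randn(8, 4),
         "time_embedding.0.bias": torch.randn(8)})
print(mm.apply(torch.randn(2, 4)).shape)  # torch.Size([2, 8])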