from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import ATTN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER


class HunyuanPreWeights(WeightModule):
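    """Input-stage ("pre") weights of the Hunyuan transformer.

    Registers weight wrappers for the image patch-embedding conv, the
    txt_in text embedder and its two-block token refiner, and the
    timestep / vector / guidance embedder MLPs.
    """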
    def __init__(self, config):
        super().__init__()
        self.config = config

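        # Image input: 3D conv patch embedding (img_in.proj) with stride (1, 2, 2)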
        self.add_module("img_in_proj", CONV3D_WEIGHT_REGISTER["Default"]("img_in.proj.weight", "img_in.proj.bias", stride=(1, 2, 2)))

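        # Text input: txt_in input embedder plus its timestep (t) and condition (c) embedder MLPs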
        self.add_module("txt_in_input_embedder", MM_WEIGHT_REGISTER["Default"]("txt_in.input_embedder.weight", "txt_in.input_embedder.bias"))
        self.add_module("txt_in_t_embedder_mlp_0", MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.0.weight", "txt_in.t_embedder.mlp.0.bias"))
        self.add_module("txt_in_t_embedder_mlp_2", MM_WEIGHT_REGISTER["Default"]("txt_in.t_embedder.mlp.2.weight", "txt_in.t_embedder.mlp.2.bias"))
        self.add_module("txt_in_c_embedder_linear_1", MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_1.weight", "txt_in.c_embedder.linear_1.bias"))
        self.add_module("txt_in_c_embedder_linear_2", MM_WEIGHT_REGISTER["Default"]("txt_in.c_embedder.linear_2.weight", "txt_in.c_embedder.linear_2.bias"))

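        # txt_in individual token refiner, block 0: layer norms, self-attention QKV/proj, MLP, adaLN modulation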
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_norm1",
            LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm1.weight", "txt_in.individual_token_refiner.blocks.0.norm1.bias", eps=1e-6),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_self_attn_qkv",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_qkv.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_self_attn_proj",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.0.self_attn_proj.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_norm2",
            LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.norm2.weight", "txt_in.individual_token_refiner.blocks.0.norm2.bias", eps=1e-6),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_mlp_fc1",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc1.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_mlp_fc2",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.0.mlp.fc2.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_0_adaLN_modulation_1",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias"),
        )

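        # txt_in individual token refiner, block 1: same layout as block 0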
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_norm1",
            LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm1.weight", "txt_in.individual_token_refiner.blocks.1.norm1.bias", eps=1e-6),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_self_attn_qkv",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_qkv.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_qkv.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_self_attn_proj",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.self_attn_proj.weight", "txt_in.individual_token_refiner.blocks.1.self_attn_proj.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_norm2",
            LN_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.norm2.weight", "txt_in.individual_token_refiner.blocks.1.norm2.bias", eps=1e-6),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_mlp_fc1",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc1.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc1.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_mlp_fc2",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.mlp.fc2.weight", "txt_in.individual_token_refiner.blocks.1.mlp.fc2.bias"),
        )
        self.add_module(
            "txt_in_individual_token_refiner_blocks_1_adaLN_modulation_1",
            MM_WEIGHT_REGISTER["Default"]("txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.weight", "txt_in.individual_token_refiner.blocks.1.adaLN_modulation.1.bias"),
        )

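        # Timestep (time_in), vector (vector_in), and guidance (guidance_in) embedder MLPs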
        self.add_module("time_in_mlp_0", MM_WEIGHT_REGISTER["Default"]("time_in.mlp.0.weight", "time_in.mlp.0.bias"))
        self.add_module("time_in_mlp_2", MM_WEIGHT_REGISTER["Default"]("time_in.mlp.2.weight", "time_in.mlp.2.bias"))
        self.add_module("vector_in_in_layer", MM_WEIGHT_REGISTER["Default"]("vector_in.in_layer.weight", "vector_in.in_layer.bias"))
        self.add_module("vector_in_out_layer", MM_WEIGHT_REGISTER["Default"]("vector_in.out_layer.weight", "vector_in.out_layer.bias"))
        self.add_module("guidance_in_mlp_0", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.0.weight", "guidance_in.mlp.0.bias"))
        self.add_module("guidance_in_mlp_2", MM_WEIGHT_REGISTER["Default"]("guidance_in.mlp.2.weight", "guidance_in.mlp.2.bias"))

        # Attention backend for txt_in (PyTorch scaled dot-product attention)
        self.add_module("txt_in_attn_1", ATTN_WEIGHT_REGISTER["torch_sdpa"]())