pre_weights.py 1.36 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import EMBEDDING_WEIGHT_REGISTER, MM_WEIGHT_REGISTER


class Qwen2PreWeights(WeightModule):
    """Register the Qwen2 weights that live outside per-layer blocks as sub-modules.

    Each sub-module is built from the ``"Default"`` implementation of a weight
    registry, constructed with one or two string keys. The keys have the form
    of checkpoint/state-dict names (e.g. ``"connector.fc1.weight"``).
    Presumably the weight classes use them to locate tensors at load time.
    TODO confirm against the registered weight classes.

    Args:
        config: Model configuration. It is stored as ``self.config`` and is not
            read anywhere inside this class.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Each entry is (sub-module name, registry, constructor keys). Entries
        # are registered in this exact order. lm_head and embed_tokens receive
        # only a weight key; no bias key is passed for them.
        specs = (
            # connector
            ("fc1", MM_WEIGHT_REGISTER, ("connector.fc1.weight", "connector.fc1.bias")),
            ("fc2", MM_WEIGHT_REGISTER, ("connector.fc2.weight", "connector.fc2.bias")),
            # language_model
            ("lm_head", MM_WEIGHT_REGISTER, ("language_model.lm_head.weight",)),
            ("embed_tokens", EMBEDDING_WEIGHT_REGISTER, ("language_model.model.embed_tokens.weight",)),
            # vae2llm
            ("vae2llm", MM_WEIGHT_REGISTER, ("vae2llm.weight", "vae2llm.bias")),
            # time_embedder
            ("mlp_0", MM_WEIGHT_REGISTER, ("time_embedder.mlp.0.weight", "time_embedder.mlp.0.bias")),
            ("mlp_2", MM_WEIGHT_REGISTER, ("time_embedder.mlp.2.weight", "time_embedder.mlp.2.bias")),
        )

        for module_name, registry, weight_keys in specs:
            weight_cls = registry["Default"]
            self.add_module(module_name, weight_cls(*weight_keys))