# pre_weights.py
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import (
    CONV3D_WEIGHT_REGISTER,
    LN_WEIGHT_REGISTER,
    MM_WEIGHT_REGISTER,
    TENSOR_REGISTER,
)


class WanPreWeights(WeightModule):
    """Weight container registering the ``Wan`` "pre" weights by checkpoint key.

    Each sub-module is created from one of the project weight registries
    (``CONV3D_WEIGHT_REGISTER``, ``MM_WEIGHT_REGISTER``, ``LN_WEIGHT_REGISTER``,
    ``TENSOR_REGISTER``) and attached with ``add_module`` under a fixed name.

    Always registered:
        ``patch_embedding``, ``text_embedding_0``, ``text_embedding_2``,
        ``time_embedding_0``, ``time_embedding_2``, ``time_projection_1``.

    Conditionally registered:
        * ``proj_0``, ``proj_1``, ``proj_3``, ``proj_4`` (``img_emb.proj.*``
          keys) when ``config.task`` is ``"i2v"`` or ``"flf2v"`` and
          ``use_image_encoder`` is truthy (defaults to ``True``).
        * ``cfg_cond_proj_1``, ``cfg_cond_proj_2`` (``guidance_embedding.*``
          keys) when ``config.model_cls`` is ``"wan2.1_distill"`` and
          ``enable_dynamic_cfg`` is truthy (defaults to ``False``).
        * ``emb_pos`` (``img_emb.emb_pos`` key) when ``config.task`` is
          ``"flf2v"``.

    Args:
        config: Model configuration. It is read with item access
            (``config["in_dim"]``), attribute access (``config.task``) and
            ``config.get(...)``, so it must support all three.
    """

    def __init__(self, config):
        super().__init__()
        self.in_dim = config["in_dim"]
        self.dim = config["dim"]
        # Also passed as the ``stride`` of the patch-embedding conv weight below.
        self.patch_size = (1, 2, 2)
        self.config = config

        # --- Weights that are registered unconditionally ---------------------
        self.add_module(
            "patch_embedding",
            CONV3D_WEIGHT_REGISTER["Default"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size),
        )
        self.add_module(
            "text_embedding_0",
            MM_WEIGHT_REGISTER["Default"]("text_embedding.0.weight", "text_embedding.0.bias"),
        )
        self.add_module(
            "text_embedding_2",
            MM_WEIGHT_REGISTER["Default"]("text_embedding.2.weight", "text_embedding.2.bias"),
        )
        self.add_module(
            "time_embedding_0",
            MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias"),
        )
        self.add_module(
            "time_embedding_2",
            MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias"),
        )
        self.add_module(
            "time_projection_1",
            MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"),
        )

        # --- Image-embedding projection (image-conditioned tasks only) -------
        if config.task in ["i2v", "flf2v"] and config.get("use_image_encoder", True):
            self.add_module(
                "proj_0",
                LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias"),
            )
            self.add_module(
                "proj_1",
                MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias"),
            )
            self.add_module(
                "proj_3",
                MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias"),
            )
            self.add_module(
                "proj_4",
                LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias"),
            )

        # --- Guidance embedding (distilled model with dynamic CFG only) ------
        if config.model_cls == "wan2.1_distill" and config.get("enable_dynamic_cfg", False):
            self.add_module(
                "cfg_cond_proj_1",
                MM_WEIGHT_REGISTER["Default"]("guidance_embedding.linear_1.weight", "guidance_embedding.linear_1.bias"),
            )
            self.add_module(
                "cfg_cond_proj_2",
                MM_WEIGHT_REGISTER["Default"]("guidance_embedding.linear_2.weight", "guidance_embedding.linear_2.bias"),
            )

        # --- Extra tensor registered for the "flf2v" task only ---------------
        if config.task == "flf2v":
            self.add_module(
                "emb_pos",
                # Plain literal: the original f-string had no placeholders.
                TENSOR_REGISTER["Default"]("img_emb.emb_pos"),
            )