pre_weights.py 1.69 KB
Newer Older
TorynCurtis's avatar
TorynCurtis committed
1
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, CONV3D_WEIGHT_REGISTER
2
from lightx2v.common.modules.weight_module import WeightModule
helloyongyang's avatar
helloyongyang committed
3

Dongz's avatar
Dongz committed
4

5
class WanPreWeights(WeightModule):
    """Weight container registered under fixed sub-module names.

    The constructor registers one weight object per sub-module through
    ``add_module``. Each weight object is built by looking up a class in
    one of the weight registries and calling it with a weight name and a
    bias name (presumably keys of the checkpoint state dict -- TODO confirm
    against the loader).

    Args:
        config: Mapping-like configuration. The keys ``"in_dim"``, ``"dim"``
            and ``"task"`` are read by subscript. The object itself is kept
            on ``self.config``.
    """

    def __init__(self, config):
        super().__init__()
        self.in_dim = config["in_dim"]
        self.dim = config["dim"]
        # Fixed stride handed to the patch-embedding conv weight below.
        self.patch_size = (1, 2, 2)
        self.config = config

        # NOTE(review): the registry key spelling "Defaultt-Force-BF16" is
        # kept verbatim; it has to match the name the registry was populated
        # with -- confirm before "fixing" the apparent typo.
        self.add_module("patch_embedding", CONV3D_WEIGHT_REGISTER["Defaultt-Force-BF16"]("patch_embedding.weight", "patch_embedding.bias", stride=self.patch_size))
        self.add_module("text_embedding_0", MM_WEIGHT_REGISTER["Default"]("text_embedding.0.weight", "text_embedding.0.bias"))
        self.add_module("text_embedding_2", MM_WEIGHT_REGISTER["Default"]("text_embedding.2.weight", "text_embedding.2.bias"))
        self.add_module("time_embedding_0", MM_WEIGHT_REGISTER["Default"]("time_embedding.0.weight", "time_embedding.0.bias"))
        self.add_module("time_embedding_2", MM_WEIGHT_REGISTER["Default"]("time_embedding.2.weight", "time_embedding.2.bias"))
        self.add_module("time_projection_1", MM_WEIGHT_REGISTER["Default"]("time_projection.1.weight", "time_projection.1.bias"))

        # Read "task" by subscript, like "in_dim" and "dim" above, so that a
        # plain dict works as well as an attribute-style config object.
        if config["task"] == "i2v":
            # Extra image-embedding projection weights, only for the i2v task.
            self.add_module("proj_0", LN_WEIGHT_REGISTER["Default"]("img_emb.proj.0.weight", "img_emb.proj.0.bias", eps=1e-5))
            self.add_module("proj_1", MM_WEIGHT_REGISTER["Default"]("img_emb.proj.1.weight", "img_emb.proj.1.bias"))
            self.add_module("proj_3", MM_WEIGHT_REGISTER["Default"]("img_emb.proj.3.weight", "img_emb.proj.3.bias"))
            self.add_module("proj_4", LN_WEIGHT_REGISTER["Default"]("img_emb.proj.4.weight", "img_emb.proj.4.bias", eps=1e-5))