from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import (
    LN_WEIGHT_REGISTER,
    MM_WEIGHT_REGISTER,
    TENSOR_REGISTER,
)


class WanPostWeights(WeightModule):
    """Weight container for the Wan model's output ("post") stage.

    On construction it registers three sub-weights. Each is built from the
    ``"Default"`` entry of its registry:

    * ``norm`` -- ``LN_WEIGHT_REGISTER["Default"]()``, constructed with no
      weight name (presumably a norm layer without learned parameters;
      TODO confirm against the registered LN weight class).
    * ``head`` -- ``MM_WEIGHT_REGISTER["Default"]`` bound to the weight names
      ``"head.head.weight"`` and ``"head.head.bias"``.
    * ``head_modulation`` -- ``TENSOR_REGISTER["Default"]`` bound to the
      weight name ``"head.modulation"``.

    Args:
        config: Model configuration. It is stored on ``self.config`` and is
            not otherwise read by this constructor.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # NOTE(review): "norm" and "head_modulation" go through
        # register_parameter while "head" goes through add_module. The
        # difference in handling is defined by WeightModule (not visible here).
        self.register_parameter(
            "norm",
            LN_WEIGHT_REGISTER["Default"](),
        )

        # The weight names below presumably match checkpoint/state-dict keys.
        # Verify against the loader before renaming.
        self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
        self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))