post_weights.py 468 Bytes
Newer Older
1
2
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER, TENSOR_REGISTER
from lightx2v.common.modules.weight_module import WeightModule
TorynCurtis's avatar
TorynCurtis committed
3
4


5
class WanPostWeights(WeightModule):
    """Weight container for the Wan model's post-processing ("head") stage.

    Registers two entries on this module:

    * ``head`` -- a weight object built by ``MM_WEIGHT_REGISTER["Default"]``
      from the checkpoint keys ``head.head.weight`` and ``head.head.bias``.
    * ``head_modulation`` -- a tensor weight built by
      ``TENSOR_REGISTER["Default"]`` from the checkpoint key
      ``head.modulation``.

    NOTE(review): when and how the underlying tensors are actually loaded is
    decided by ``WeightModule`` and the registered weight classes, which are
    defined elsewhere -- confirm there.
    """

    def __init__(self, config):
        """Store ``config`` and register the head weights.

        Args:
            config: Model configuration. This class only stores it on
                ``self.config``; presumably it is read by callers or by
                ``WeightModule`` helpers -- TODO confirm.
        """
        super().__init__()
        self.config = config

        # Weight + bias pair for the output head, created through the
        # "Default" entry of the MM weight registry (presumably a matmul /
        # linear-layer weight -- confirm against the registry).
        self.add_module(
            "head",
            MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"),
        )
        # Standalone tensor (not an MM weight), registered as a parameter
        # rather than a submodule.
        self.register_parameter(
            "head_modulation",
            TENSOR_REGISTER["Default"]("head.modulation"),
        )