"models-2.13.1/official/projects/movinet/README.md" did not exist on "e53ccd8010d8359bd9c3a641101bb0f8c364007d"
post_weights.py 1.41 KB
Newer Older
Watebear's avatar
Watebear committed
1
2
from lightx2v.common.ops.mm.mm_weight import MMWeightTemplate
from lightx2v.common.ops.norm.layer_norm_weight import LNWeightTemplate
PengGao's avatar
PengGao committed
3
from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER, MM_WEIGHT_REGISTER
Watebear's avatar
Watebear committed
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31


class CogvideoxPostWeights:
    """Container for the output-side ("post") weights of the CogVideoX transformer.

    The weight objects are created and filled by :meth:`load_weights`.
    :meth:`to_cpu` and :meth:`to_cuda` then forward the device move to every
    loaded weight.
    """

    def __init__(self, config, mm_type="Default"):
        """Store the configuration and the weight-implementation key.

        Args:
            config: Model configuration. It is stored as ``self.config`` and is
                not read elsewhere in this class.
            mm_type: Registry key used to select the weight classes from
                ``MM_WEIGHT_REGISTER`` and ``LN_WEIGHT_REGISTER``.
        """
        self.config = config
        self.mm_type = mm_type
        # Start empty so to_cpu()/to_cuda() are harmless no-ops if called before
        # load_weights(), instead of failing with AttributeError.
        self.weight_list = []

    def _iter_weights(self):
        """Yield each entry of ``weight_list`` that is an MM or LN weight template."""
        for weight in self.weight_list:
            if isinstance(weight, (MMWeightTemplate, LNWeightTemplate)):
                yield weight

    def load_weights(self, weight_dict):
        """Create the post-block weight objects and load them from ``weight_dict``.

        Four weight objects are built, one per parameter group:

        - ``norm_out.linear`` (matmul weight)
        - ``proj_out`` (matmul weight)
        - ``norm_final`` (layer-norm weight)
        - ``norm_out.norm`` (layer-norm weight)

        Each object's ``load`` method is then called with the whole dict.

        Args:
            weight_dict: Mapping keyed by checkpoint parameter names such as
                ``"proj_out.weight"``. Values are presumably tensors; confirm
                against the loader that builds this dict.
        """
        self.norm_out_linear = MM_WEIGHT_REGISTER[self.mm_type]("norm_out.linear.weight", "norm_out.linear.bias")
        self.proj_out = MM_WEIGHT_REGISTER[self.mm_type]("proj_out.weight", "proj_out.bias")
        # NOTE(review): the layer-norm classes are looked up with the *matmul*
        # type key. If mm_type names an MM variant that LN_WEIGHT_REGISTER does
        # not register, this raises KeyError. Confirm whether LN should use its
        # own key (e.g. "Default").
        self.norm_final = LN_WEIGHT_REGISTER[self.mm_type]("norm_final.weight", "norm_final.bias")
        # eps is passed explicitly here only; norm_final uses the LN class default.
        self.norm_out_norm = LN_WEIGHT_REGISTER[self.mm_type]("norm_out.norm.weight", "norm_out.norm.bias", eps=1e-5)

        self.weight_list = [self.norm_out_linear, self.proj_out, self.norm_final, self.norm_out_norm]

        for weight in self._iter_weights():
            weight.load(weight_dict)

    def to_cpu(self):
        """Call ``to_cpu()`` on every loaded MM/LN weight template (no-op before loading)."""
        for weight in self._iter_weights():
            weight.to_cpu()

    def to_cuda(self):
        """Call ``to_cuda()`` on every loaded MM/LN weight template (no-op before loading)."""
        for weight in self._iter_weights():
            weight.to_cuda()