post_weights.py 739 Bytes
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import (
    MM_WEIGHT_REGISTER,
    RMS_WEIGHT_REGISTER,
)


class Qwen2PostWeights(WeightModule):
    """Weights applied after the Qwen2 decoder layers.

    Registers three sub-modules, in this order:

    * ``norm`` -- RMS-norm weight built from the ``"fp32_variance"``
      registry entry for ``language_model.model.norm.weight``.
    * ``norm_moe_gen`` -- a second RMS-norm weight (same registry entry)
      for ``language_model.model.norm_moe_gen.weight``.
    * ``llm2vae`` -- weight/bias pair from the ``"Default"`` MM registry
      entry for ``llm2vae.weight`` / ``llm2vae.bias``.

    The string arguments are presumably checkpoint parameter names that the
    weight objects load from -- confirm against the registry classes.

    Args:
        config: Model configuration; stored on ``self.config`` but not
            otherwise read here.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Both final norms share the same RMS weight implementation; register
        # them in the original order ("norm" first, then "norm_moe_gen").
        rms_weight_cls = RMS_WEIGHT_REGISTER["fp32_variance"]
        norm_specs = (
            ("norm", "language_model.model.norm.weight"),
            ("norm_moe_gen", "language_model.model.norm_moe_gen.weight"),
        )
        for module_name, weight_key in norm_specs:
            self.add_module(module_name, rms_weight_cls(weight_key))

        # llm2vae: weight + bias pair from the default MM weight implementation.
        mm_weight_cls = MM_WEIGHT_REGISTER["Default"]
        self.add_module(
            "llm2vae",
            mm_weight_cls("llm2vae.weight", "llm2vae.bias"),
        )