post_weights.py 2.15 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
Post-weights module for LTX2 transformer model.

This module handles the output processing weights including:
- Scale-shift table
- Output normalization
- Output projection
"""

from lightx2v.common.modules.weight_module import WeightModule
from lightx2v.utils.registry_factory import (
    LN_WEIGHT_REGISTER,
    MM_WEIGHT_REGISTER,
    TENSOR_REGISTER,
)


class LTX2PostWeights(WeightModule):
    """
    Weight container for the output head of the LTX2 transformer.

    Registers, in order, three sub-modules for the video stream
    (``scale_shift_table``, ``norm_out``, ``proj_out``) followed by the
    same three for the audio stream, whose module names and checkpoint
    keys carry an ``audio_`` prefix.
    """

    def __init__(self, config):
        """
        Build and register the post-transformer weight modules.

        Args:
            config: Model configuration. It is stored as ``self.config``
                and is not otherwise read by this constructor, so the
                audio modules are registered regardless of its contents.
        """
        super().__init__()
        self.config = config

        # "" selects the video stream, "audio_" the audio stream. The tuple
        # order fixes the registration order: all video modules first.
        for stream in ("", "audio_"):
            ckpt_key = f"model.diffusion_model.{stream}"

            table = TENSOR_REGISTER["Default"](f"{ckpt_key}scale_shift_table")
            self.add_module(f"{stream}scale_shift_table", table)

            # The normalisation module is created without checkpoint keys.
            norm = LN_WEIGHT_REGISTER["torch"]()
            self.add_module(f"{stream}norm_out", norm)

            projection = MM_WEIGHT_REGISTER["Default"](
                f"{ckpt_key}proj_out.weight",
                f"{ckpt_key}proj_out.bias",
            )
            self.add_module(f"{stream}proj_out", projection)