import glob
import os

import torch
import torch.distributed as dist
from loguru import logger
from safetensors import safe_open

from lightx2v.models.networks.hunyuan_video.model import HunyuanVideo15Model
from lightx2v.models.networks.worldplay.infer.bi_transformer_infer import WorldPlayBITransformerInfer
from lightx2v.models.networks.worldplay.infer.post_infer import WorldPlayPostInfer
from lightx2v.models.networks.worldplay.infer.pre_infer import WorldPlayPreInfer
from lightx2v.models.networks.worldplay.weights.post_weights import WorldPlayPostWeights
from lightx2v.models.networks.worldplay.weights.pre_weights import WorldPlayPreWeights
from lightx2v.models.networks.worldplay.weights.transformer_weights import WorldPlayTransformerWeights
from lightx2v.utils.envs import *


class WorldPlayBIModel(HunyuanVideo15Model):
    """
    WorldPlay BI (Bidirectional) model with action conditioning and ProPE support.

    Key differences from AR model:
    - Uses bidirectional attention (not causal)
    - No KV cache required
    - Standard 50-step inference
    - Supports classifier-free guidance

    Key differences from Distill model:
    - Standard 50-step inference (not 4-step)
    - Uses guidance_scale for CFG

    Extends HunyuanVideo15Model with:
    - Action conditioning via action_in embedder
    - ProPE (Projective Positional Encoding) for camera pose conditioning
    - Support for loading separate action model checkpoint
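
    Example (illustrative wiring; argument values are placeholders):
        model = WorldPlayBIModel(model_path, config, device, action_ckpt=ckpt_path)
        model.set_scheduler(scheduler)
        output = model.infer_bi(inputs)  # first chunk, no context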
    """

    def __init__(self, model_path, config, device, action_ckpt=None):
        """Initialize the BI model.

        Args:
            model_path: Path to the base model weights.
            config: Model configuration dict.
            device: Target torch device.
            action_ckpt: Optional path to an action checkpoint, either a single
                .safetensors file or a directory of .safetensors files.
        """
        self.action_ckpt = action_ckpt
        # BI model needs byt5_in and vision_in weights, so don't remove them
        # This must be set before calling super().__init__() which sets remove_keys
        self._bi_model_keep_all_keys = config.get("use_bi_model_as_main", False)
        super().__init__(model_path, config, device)
        # Override remove_keys if using bi model as main (it has all weights)
        if self._bi_model_keep_all_keys:
            self.remove_keys = []

    def _init_infer_class(self):
        """Initialize WorldPlay BI-specific inference classes."""
        self.pre_infer_class = WorldPlayPreInfer
        self.post_infer_class = WorldPlayPostInfer

        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = WorldPlayBITransformerInfer
        else:
            # Fall back to base transformer for caching modes
            from lightx2v.models.networks.hunyuan_video.infer.feature_caching.transformer_infer import (
                HunyuanTransformerInferTeaCaching,
                HunyuanVideo15TransformerInferMagCaching,
            )

            if self.config["feature_caching"] == "Mag":
                self.transformer_infer_class = HunyuanVideo15TransformerInferMagCaching
            elif self.config["feature_caching"] == "Tea":
                self.transformer_infer_class = HunyuanTransformerInferTeaCaching
            else:
                raise NotImplementedError(f"Feature caching {self.config['feature_caching']} not supported")

    def _init_weights(self):
        """Initialize weights including action conditioning weights.

        For BI model, the action_ckpt contains the COMPLETE model weights
        (not just action-related weights). When use_bi_model_as_main is True,
        we load directly from action_ckpt instead of merging with base model.
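
        Example (illustrative; only the config key read below is shown):
            config["use_bi_model_as_main"] = True   # load everything from action_ckpt
            config["use_bi_model_as_main"] = False  # merge action weights into base model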
        """
        unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
        sensitive_layer = {}

        use_bi_model_as_main = self.config.get("use_bi_model_as_main", False)

        if use_bi_model_as_main and self.action_ckpt is not None:
            # BI model: action_ckpt contains complete model weights
            logger.info("Loading BI model weights directly from action_ckpt (complete model)")
            weight_dict = self._load_action_ckpt(unified_dtype, sensitive_layer)
        else:
            # Legacy mode: load base model and merge action weights
            if not self.dit_quantized:
                weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
            else:
                weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)

            # Load action model weights if provided
            if self.action_ckpt is not None:
                action_weight_dict = self._load_action_ckpt(unified_dtype, sensitive_layer)
                weight_dict.update(action_weight_dict)

        self.original_weight_dict = weight_dict
        self.pre_weight = WorldPlayPreWeights(self.config)
        self.transformer_weights = WorldPlayTransformerWeights(self.config)
        self.post_weight = WorldPlayPostWeights(self.config)
        self._apply_weights()

    def _load_action_ckpt(self, unified_dtype, sensitive_layer):
        """Load action model checkpoint.

        For BI model with use_bi_model_as_main=True, this loads the complete model
        including byt5_in and vision_in weights.
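
        Example (illustrative paths):
            action_ckpt = "/path/to/action.safetensors"   # single file
            action_ckpt = "/path/to/action_ckpt_dir/"     # directory of *.safetensors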
        """
        action_ckpt = self.action_ckpt

        if os.path.isdir(action_ckpt):
            safetensors_files = glob.glob(os.path.join(action_ckpt, "*.safetensors"))
        else:
            safetensors_files = [action_ckpt]

        weight_dict = {}
        for file_path in safetensors_files:
            logger.info(f"Loading action weights from {file_path}")
            # Use _load_safetensor_to_dict_no_filter to keep all keys for BI model
            file_weights = self._load_safetensor_to_dict_no_filter(file_path, unified_dtype, sensitive_layer)
            weight_dict.update(file_weights)

        return weight_dict

    def _load_safetensor_to_dict_no_filter(self, file_path, unified_dtype, sensitive_layer):
        """Load safetensor without filtering any keys (for BI model complete weights)."""
        if self.device.type != "cpu" and dist.is_initialized():
            device = dist.get_rank()
        else:
            device = str(self.device)

        with safe_open(file_path, framework="pt", device=device) as f:
            weight_dict = {}
            for key in f.keys():
                tensor = f.get_tensor(key)
                # Sensitive layers keep the sensitive dtype unless dtypes are unified
                if unified_dtype or all(s not in key for s in sensitive_layer):
                    weight_dict[key] = tensor.to(GET_DTYPE())
                else:
                    weight_dict[key] = tensor.to(GET_SENSITIVE_DTYPE())
            return weight_dict

    def _init_infer(self):
        """Initialize inference modules and connect action weights."""
        super()._init_infer()

        # Connect action weights to transformer for ProPE projection
        if hasattr(self.pre_weight, "action_weights") and hasattr(self.transformer_infer, "set_action_weights"):
            self.transformer_infer.set_action_weights(self.pre_weight.action_weights)

    def set_scheduler(self, scheduler):
        """Set the scheduler via the base implementation."""
        super().set_scheduler(scheduler)

    @torch.no_grad()
    def infer(self, inputs):
        """
        Run inference with action and camera pose conditioning.

        Args:
            inputs: Dict containing:
                - text_encoder_output: Text encoder outputs
                - image_encoder_output: Image encoder outputs
                - pose_output (optional): Dict with viewmats, Ks, action
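
        Example (illustrative; tensor values are placeholders):
            inputs["pose_output"] = {
                "viewmats": viewmats,  # camera view matrices
                "Ks": Ks,              # camera intrinsics
                "action": action,      # action conditioning sequence
            }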
        """
        # Store pose data in scheduler if provided
        if "pose_output" in inputs and inputs["pose_output"] is not None:
            pose_output = inputs["pose_output"]
            if "viewmats" in pose_output:
                self.scheduler.viewmats = pose_output["viewmats"]
            if "Ks" in pose_output:
                self.scheduler.Ks = pose_output["Ks"]
            if "action" in pose_output:
                self.scheduler.action = pose_output["action"]

        # Call parent inference
        super().infer(inputs)

    @torch.no_grad()
    def infer_bi(self, inputs, context_inputs=None):
        """
        Run bidirectional inference for a chunk.

        This is the main inference method for BI model, supporting:
        - Context frame concatenation for non-first chunks
        - Per-token timestep for context vs current frames
        - Classifier-free guidance

        Args:
            inputs: Dict containing encoder outputs and current chunk data
            context_inputs: Optional dict with context frame data for non-first chunks

        Returns:
            Noise prediction output
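
        Example (illustrative chunked loop; runner wiring is an assumption):
            out = model.infer_bi(first_chunk_inputs)                # first chunk
            out = model.infer_bi(chunk_inputs, context_inputs=ctx)  # later chunks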
        """
        # Note: pose data (viewmats, Ks, action) should already be set in scheduler
        # by the runner before calling this method. Don't override here.

        # Run pre-inference
        infer_module_out = self.pre_infer.infer(self.pre_weight, inputs)

        # If context inputs provided, merge them
        if context_inputs is not None:
            infer_module_out = self._merge_context(infer_module_out, context_inputs)

        # Run transformer inference
        x = self.transformer_infer.infer(self.transformer_weights, infer_module_out)

        # Run post-inference (infer_module_out carries grid_sizes)
        output = self.post_infer.infer(x, infer_module_out)

        return output

    def _merge_context(self, infer_module_out, context_inputs):
        """
        Merge context frame data with current chunk data.

        For BI model, context frames are concatenated with current frames
        and processed together with bidirectional attention.

        Args:
            infer_module_out: Current chunk inference module output
            context_inputs: Context frame data

        Returns:
            Merged inference module output
        """
        # The actual concatenation of context and current frames happens
        # downstream in transformer_infer; here we only attach the context data.
        infer_module_out.context_inputs = context_inputs
        return infer_module_out