# ar_model.py
import glob
import os

import torch
from loguru import logger

from lightx2v.models.networks.hunyuan_video.model import HunyuanVideo15Model
from lightx2v.models.networks.worldplay.infer.ar_pre_infer import WorldPlayARPreInfer
from lightx2v.models.networks.worldplay.infer.ar_transformer_infer import WorldPlayARTransformerInfer
from lightx2v.models.networks.worldplay.infer.post_infer import WorldPlayPostInfer
from lightx2v.models.networks.worldplay.weights.post_weights import WorldPlayPostWeights
from lightx2v.models.networks.worldplay.weights.pre_weights import WorldPlayPreWeights
from lightx2v.models.networks.worldplay.weights.transformer_weights import WorldPlayTransformerWeights
from lightx2v.utils.envs import *


class WorldPlayARModel(HunyuanVideo15Model):
    """
    WorldPlay AR (Autoregressive) model with action conditioning and ProPE support.

    Extends HunyuanVideo15Model with:
    - Action conditioning via action_in embedder
    - ProPE (Projective Positional Encoding) for camera pose conditioning
    - KV Cache for autoregressive generation
    - Causal attention mechanism
    - Support for loading separate action model checkpoint

    Key differences from WorldPlayModel (Distill):
    - Uses causal attention instead of bidirectional
    - Implements KV cache for frame-by-frame generation
    - No guidance embedding required
    """

    def __init__(self, model_path, config, device, action_ckpt=None):
        self.action_ckpt = action_ckpt
        super().__init__(model_path, config, device)

    def _init_infer_class(self):
        """Initialize WorldPlay AR-specific inference classes."""
        self.pre_infer_class = WorldPlayARPreInfer
        self.post_infer_class = WorldPlayPostInfer

        if self.config["feature_caching"] == "NoCaching":
            self.transformer_infer_class = WorldPlayARTransformerInfer
        else:
            raise NotImplementedError(f"Feature caching {self.config['feature_caching']} not supported for AR model. AR model requires NoCaching due to KV cache management.")

    def _init_weights(self):
        """Initialize weights including action conditioning weights."""
        unified_dtype = GET_DTYPE() == GET_SENSITIVE_DTYPE()
        sensitive_layer = {}

        if not self.dit_quantized:
            weight_dict = self._load_ckpt(unified_dtype, sensitive_layer)
        else:
            weight_dict = self._load_quant_ckpt(unified_dtype, sensitive_layer)

        # Load action model weights if provided
        if self.action_ckpt is not None:
            action_weight_dict = self._load_action_ckpt(unified_dtype, sensitive_layer)
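            # dict.update: action-model weights override any overlapping base keys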
            weight_dict.update(action_weight_dict)

        self.original_weight_dict = weight_dict
        self.pre_weight = WorldPlayPreWeights(self.config)
        self.transformer_weights = WorldPlayTransformerWeights(self.config)
        self.post_weight = WorldPlayPostWeights(self.config)
        self._apply_weights()

    def _load_action_ckpt(self, unified_dtype, sensitive_layer):
        """Load action model checkpoint."""
        action_ckpt = self.action_ckpt

        if os.path.isdir(action_ckpt):
            safetensors_files = glob.glob(os.path.join(action_ckpt, "*.safetensors"))
        else:
            safetensors_files = [action_ckpt]

        weight_dict = {}
        for file_path in safetensors_files:
            logger.info(f"Loading action weights from {file_path}")
            file_weights = self._load_safetensor_to_dict(file_path, unified_dtype, sensitive_layer)
            weight_dict.update(file_weights)

        return weight_dict
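
    # Note: action_ckpt may be either a directory (every *.safetensors inside is
    # merged, with later files overwriting earlier keys via dict.update) or a
    # single .safetensors file. The paths below are illustrative only:
    #   action_ckpt = "/path/to/action_model/"             # directory of shards
    #   action_ckpt = "/path/to/action_model.safetensors"  # single checkpoint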

    def _init_infer(self):
        """Initialize inference modules and connect action weights."""
        super()._init_infer()

        # Connect action weights to transformer for ProPE projection
        if hasattr(self.pre_weight, "action_weights") and hasattr(self.transformer_infer, "set_action_weights"):
            self.transformer_infer.set_action_weights(self.pre_weight.action_weights)

    def set_scheduler(self, scheduler):
        """Set the scheduler (delegates to the parent implementation)."""
        super().set_scheduler(scheduler)

    def _store_pose_data(self, inputs):
        """Store camera pose and action conditioning from inputs on the scheduler, if present."""
        pose_output = inputs.get("pose_output")
        if pose_output is None:
            return
        if "viewmats" in pose_output:
            self.scheduler.viewmats = pose_output["viewmats"]
        if "Ks" in pose_output:
            self.scheduler.Ks = pose_output["Ks"]
        if "action" in pose_output:
            self.scheduler.action = pose_output["action"]

    def init_kv_cache(self):
        """
        Initialize KV cache for autoregressive generation.
        The cache structure matches the original HY-WorldPlay implementation.
        """
        if hasattr(self.transformer_infer, "init_kv_cache"):
            return self.transformer_infer.init_kv_cache()
        return None

    def clear_kv_cache(self):
        """Clear KV cache after generation."""
        if hasattr(self.transformer_infer, "clear_kv_cache"):
            self.transformer_infer.clear_kv_cache()

    def clear_vision_cache(self):
        """Clear only vision cache (keep text cache for next chunk)."""
        if hasattr(self.transformer_infer, "clear_vision_cache"):
            self.transformer_infer.clear_vision_cache()
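
    # Cache lifecycle implied by the three methods above: init_kv_cache() once per
    # generation, clear_vision_cache() between chunks (so the text KV cached by
    # infer_txt can be reused), and clear_kv_cache() once the full clip is done.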

    # ========== AR-specific inference methods ==========

    @torch.no_grad()
    def infer_txt(self, inputs, cache_txt=True):
        """
        Cache the text KV; called once at the beginning of generation.
        Corresponds to original forward_txt().

        Args:
            inputs: Dict containing text_encoder_output, image_encoder_output
            cache_txt: Whether to cache text KV (default True)

        Returns:
            KV cache reference
        """
        # Initialize KV cache if not already done
        if not hasattr(self.transformer_infer, "_kv_cache") or self.transformer_infer._kv_cache is None:
            self.init_kv_cache()

        # Run text-only pre-processing
        infer_module_out = self.pre_infer.infer_txt_only(self.pre_weight, inputs)

        # Cache text KV
        return self.transformer_infer.infer_txt(self.transformer_weights, infer_module_out, cache_txt=cache_txt)

    @torch.no_grad()
    def infer_vision(self, inputs, cache_vision=False):
        """
        Vision inference using cached text KV.
        Corresponds to original forward_vision().

        Args:
            inputs: Dict containing encoder outputs and pose data
            cache_vision: Whether to cache vision KV for context frames

        Returns:
            If cache_vision=True: KV cache reference
            If cache_vision=False: Noise prediction output
        """
        # Store pose/action conditioning in the scheduler if provided
        self._store_pose_data(inputs)

        # Run pre-inference (full, including image)
        infer_module_out = self.pre_infer.infer(self.pre_weight, inputs)

        # Vision inference with KV cache
        output = self.transformer_infer.infer_vision(self.transformer_weights, infer_module_out, cache_vision=cache_vision)

        if cache_vision:
            return output  # Return KV cache
        else:
            # Run post-inference
            return self.post_infer.infer(self.post_weight, output)
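
    # A sketch of how the cache_vision flag is typically used (inferred from the
    # docstring above, not prescriptive): context frames are passed once with
    # cache_vision=True so their KV is appended to the cache, while the frame being
    # denoised runs with cache_vision=False at each denoising step and only reads it.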

    @torch.no_grad()
    def infer(self, inputs):
        """
        Run inference with action and camera pose conditioning.

        Args:
            inputs: Dict containing:
                - text_encoder_output: Text encoder outputs
                - image_encoder_output: Image encoder outputs
                - pose_output (optional): Dict with viewmats, Ks, action
        """
        # Store pose/action conditioning in the scheduler if provided
        self._store_pose_data(inputs)

        # Call parent inference
        super().infer(inputs)

    @torch.no_grad()
    def infer_chunk(self, inputs, chunk_idx, total_chunks):
        """
        Run inference for a single chunk in autoregressive generation.

        Args:
            inputs: Dict containing encoder outputs and pose data
            chunk_idx: Current chunk index (0-indexed)
            total_chunks: Total number of chunks

        Returns:
            Latent tensor for this chunk
        """
        # Store chunk info in scheduler
        self.scheduler.chunk_idx = chunk_idx
        self.scheduler.total_chunks = total_chunks

        # Store pose/action conditioning in the scheduler if provided
        self._store_pose_data(inputs)

        # Run pre-inference
        infer_module_out = self.pre_infer.infer(self.pre_weight, inputs)

        # Run transformer with KV cache
        x = self.transformer_infer.infer(self.transformer_weights, infer_module_out)

        # Run post-inference
        output = self.post_infer.infer(self.post_weight, x)

        return output
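
# --- Illustrative usage sketch (not part of the original file) ---
# The sequence below is inferred from the docstrings above; `model_path`, `config`,
# `scheduler`, `build_txt_inputs`, and `chunk_inputs` are hypothetical placeholders
# supplied by the surrounding pipeline, not names defined in this module.
#
#   model = WorldPlayARModel(model_path, config, device="cuda", action_ckpt=action_ckpt)
#   model.set_scheduler(scheduler)
#   model.init_kv_cache()                       # allocate the KV cache once
#   model.infer_txt(build_txt_inputs())         # cache text KV once per generation
#   latents = []
#   for idx in range(total_chunks):
#       latents.append(model.infer_chunk(chunk_inputs(idx), idx, total_chunks))
#       model.clear_vision_cache()              # drop vision KV, keep text KV
#   model.clear_kv_cache()                      # release everything when finished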