worldplay_ar_dataset.py 13.4 KB
Newer Older
litzh's avatar
litzh committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
"""
WorldPlay AR Dataset for autoregressive video generation training.

This dataset supports:
- Camera pose (w2c, intrinsic) loading
- Action label extraction
- Image conditioning for I2V
- Chunk-based training with memory window
- I2V masking for conditional generation
"""

import json
import os
import random
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from loguru import logger
from torch.utils.data import Dataset

try:
    from PIL import Image
except ImportError:
    Image = None

try:
    import decord

    decord.bridge.set_bridge("torch")
except ImportError:
    decord = None


class WorldPlayARDataset(Dataset):
    """
    Dataset for WorldPlay AR model training.

    Supports loading video data with camera poses and action labels
    for autoregressive video generation training.

    Args:
        data_root: Root directory containing video data
        meta_file: Path to metadata JSON file
        video_length: Number of frames per video sample
        resolution: Target resolution (height, width)
        chunk_latent_num: Number of latent frames per chunk
        memory_window_size: Size of memory window for AR training
        select_window_out_flag: Whether to use memory window selection
        task: Task type ('t2v' or 'i2v')
        transform: Optional transform to apply to frames
        action_trans_thresh: Translation threshold for action quantization
        action_rot_thresh: Rotation threshold for action quantization
        num_action_classes: Number of discrete action classes (default 81 = 3^4)
    """

    def __init__(
        self,
        data_root: str,
        meta_file: str,
        video_length: int = 125,
        resolution: Tuple[int, int] = (480, 832),
        chunk_latent_num: int = 4,
        memory_window_size: int = 8,
        select_window_out_flag: bool = True,
        task: str = "i2v",
        transform: Optional[Any] = None,
        action_trans_thresh: float = 0.1,
        action_rot_thresh: float = 0.05,
        num_action_classes: int = 81,
    ):
        super().__init__()
        self.data_root = data_root
        self.video_length = video_length
        self.resolution = resolution
        self.chunk_latent_num = chunk_latent_num
        self.memory_window_size = memory_window_size
        self.select_window_out_flag = select_window_out_flag
        self.task = task
        self.transform = transform

        # Action quantization parameters
        self.action_trans_thresh = action_trans_thresh
        self.action_rot_thresh = action_rot_thresh
        self.num_action_classes = num_action_classes

        # Load metadata
        self.samples = self._load_metadata(meta_file)
        logger.info(f"Loaded {len(self.samples)} samples from {meta_file}")

    def _load_metadata(self, meta_file: str) -> List[Dict]:
        """Load and normalize sample metadata from a JSON file.

        Each entry gets its paths joined onto ``data_root``; the optional
        keys (w2c / intrinsic / action / image_cond) are copied through
        only when present.
        """
        with open(meta_file, "r") as f:
            data = json.load(f)

        samples = []
        for item in data:
            sample = {
                "video_path": os.path.join(self.data_root, item["video_path"]),
                "caption": item.get("caption", ""),
            }

            # Camera pose data (optional per sample)
            if "w2c" in item:
                sample["w2c"] = item["w2c"]
            if "intrinsic" in item:
                sample["intrinsic"] = item["intrinsic"]
            if "action" in item:
                sample["action"] = item["action"]

            # Image conditioning for I2V (optional per sample)
            if "image_cond" in item:
                sample["image_cond"] = os.path.join(self.data_root, item["image_cond"])

            samples.append(sample)

        return samples

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return one training sample: frames, poses, actions, mask, caption."""
        sample = self.samples[idx]

        # Load video frames [T, C, H, W] in [0, 1]
        video = self._load_video(sample["video_path"])

        # Load camera poses (identity / default intrinsics when absent)
        w2c = self._load_camera_pose(sample.get("w2c"))
        intrinsic = self._load_intrinsic(sample.get("intrinsic"))

        # Load action labels, or derive them from pose deltas
        action = self._load_action(sample.get("action"), w2c)

        # Load image condition for I2V
        image_cond = None
        if self.task == "i2v" and "image_cond" in sample:
            image_cond = self._load_image(sample["image_cond"])

        # Prepare I2V mask (0 = conditioned frame, 1 = generated frame)
        i2v_mask = self._prepare_i2v_mask(video.shape[0])

        # Optionally crop everything to a random memory window
        if self.select_window_out_flag:
            (video, w2c, intrinsic, action, i2v_mask) = self._select_memory_window(video, w2c, intrinsic, action, i2v_mask)

        output = {
            "video": video,
            "caption": sample["caption"],
            "w2c": w2c,
            "intrinsic": intrinsic,
            "action": action,
            "i2v_mask": i2v_mask,
        }

        if image_cond is not None:
            output["image_cond"] = image_cond

        return output

    def _load_video(self, video_path: str) -> torch.Tensor:
        """Load, sample, and resize video frames.

        Returns:
            Float tensor of shape [video_length, C, H, W] in [0, 1],
            resized to ``self.resolution``.

        Raises:
            ImportError: If decord is not installed.
            ValueError: If the video contains no frames.
        """
        if decord is None:
            raise ImportError("decord is required for video loading")

        vr = decord.VideoReader(video_path)
        total_frames = len(vr)
        # Bug fix: with 0 frames the padding below would produce index -1,
        # silently reading "the last frame" of an empty video. Fail loudly.
        if total_frames == 0:
            raise ValueError(f"Video has no frames: {video_path}")

        # Sample a random contiguous clip, padding short videos with the
        # last frame so every sample has exactly video_length frames.
        if total_frames >= self.video_length:
            start_idx = random.randint(0, total_frames - self.video_length)
            frame_indices = list(range(start_idx, start_idx + self.video_length))
        else:
            frame_indices = list(range(total_frames))
            frame_indices += [total_frames - 1] * (self.video_length - total_frames)

        frames = vr.get_batch(frame_indices)  # [T, H, W, C] uint8 (torch bridge)
        frames = frames.permute(0, 3, 1, 2).float() / 255.0  # [T, C, H, W]

        # Resize to target resolution
        frames = F.interpolate(frames, size=self.resolution, mode="bilinear", align_corners=False)

        if self.transform is not None:
            frames = self.transform(frames)

        return frames

    def _load_camera_pose(self, w2c_data: Optional[Any]) -> torch.Tensor:
        """Load world-to-camera transformation matrices.

        Accepts a .npy path, a nested list, an ndarray, or a tensor.

        Returns:
            Float tensor of shape [video_length, 4, 4]; identity matrices
            when no pose data is provided.
        """
        if w2c_data is None:
            # No pose data: fall back to identity for every frame.
            return torch.eye(4).unsqueeze(0).repeat(self.video_length, 1, 1)

        if isinstance(w2c_data, str):
            # Load from file
            w2c = np.load(w2c_data)
        else:
            w2c = w2c_data

        # Bug fix: torch.from_numpy crashed when metadata already held a
        # tensor (or a list reached the fall-through). as_tensor accepts
        # ndarray / tensor / sequence uniformly and avoids a copy when it can.
        w2c = torch.as_tensor(w2c, dtype=torch.float32)

        # Ensure correct shape [T, 4, 4]: broadcast a single pose, pad a
        # short trajectory with its last pose, truncate a long one.
        if w2c.dim() == 2:
            w2c = w2c.unsqueeze(0).repeat(self.video_length, 1, 1)
        elif w2c.shape[0] < self.video_length:
            pad_size = self.video_length - w2c.shape[0]
            w2c = torch.cat([w2c, w2c[-1:].repeat(pad_size, 1, 1)], dim=0)
        elif w2c.shape[0] > self.video_length:
            w2c = w2c[: self.video_length]

        return w2c

    def _load_intrinsic(self, intrinsic_data: Optional[Any]) -> torch.Tensor:
        """Load camera intrinsic matrices.

        Accepts a .npy path, a nested list, an ndarray, or a tensor.

        Returns:
            Float tensor of shape [video_length, 3, 3]; a default pinhole
            K (f=500, principal point at image center) when absent.
        """
        if intrinsic_data is None:
            # Default intrinsics: focal length 500, principal point centered.
            K = torch.tensor([[500.0, 0.0, self.resolution[1] / 2], [0.0, 500.0, self.resolution[0] / 2], [0.0, 0.0, 1.0]])
            return K.unsqueeze(0).repeat(self.video_length, 1, 1)

        if isinstance(intrinsic_data, str):
            intrinsic = np.load(intrinsic_data)
        else:
            intrinsic = intrinsic_data

        # Bug fix: as_tensor instead of from_numpy — see _load_camera_pose.
        intrinsic = torch.as_tensor(intrinsic, dtype=torch.float32)

        # Ensure correct shape [T, 3, 3] (broadcast / pad / truncate)
        if intrinsic.dim() == 2:
            intrinsic = intrinsic.unsqueeze(0).repeat(self.video_length, 1, 1)
        elif intrinsic.shape[0] < self.video_length:
            pad_size = self.video_length - intrinsic.shape[0]
            intrinsic = torch.cat([intrinsic, intrinsic[-1:].repeat(pad_size, 1, 1)], dim=0)
        elif intrinsic.shape[0] > self.video_length:
            intrinsic = intrinsic[: self.video_length]

        return intrinsic

    def _load_action(self, action_data: Optional[Any], w2c: torch.Tensor) -> torch.Tensor:
        """Load action labels, or derive them from camera pose deltas.

        Returns:
            Long tensor of shape [video_length] with discrete action ids.
        """
        if action_data is not None:
            if isinstance(action_data, str):
                action_data = np.load(action_data)
            # Bug fix: as_tensor instead of from_numpy — see _load_camera_pose.
            action = torch.as_tensor(action_data, dtype=torch.long)
        else:
            # No explicit labels: quantize frame-to-frame camera motion.
            action = self._compute_action_from_pose(w2c)

        # Ensure correct shape [T] (pad with last label / truncate)
        if action.shape[0] < self.video_length:
            pad_size = self.video_length - action.shape[0]
            action = torch.cat([action, action[-1:].repeat(pad_size)], dim=0)
        elif action.shape[0] > self.video_length:
            action = action[: self.video_length]

        return action

    def _compute_action_from_pose(self, w2c: torch.Tensor) -> torch.Tensor:
        """
        Compute discrete action labels from camera pose differences.

        Action space: 81 classes (3^4 for forward/backward, left/right,
        up/down, rotation). Frame 0 has no predecessor and keeps label 0.
        """
        T = w2c.shape[0]
        actions = torch.zeros(T, dtype=torch.long)

        for t in range(1, T):
            # Relative transform taking frame t-1's camera to frame t's.
            rel_pose = torch.inverse(w2c[t - 1]) @ w2c[t]

            # Extract translation and rotation components
            translation = rel_pose[:3, 3]
            rotation = rel_pose[:3, :3]

            # Quantize to a discrete action id
            actions[t] = self._quantize_action(translation, rotation)

        return actions

    def _quantize_action(self, translation: torch.Tensor, rotation: torch.Tensor) -> int:
        """Quantize continuous motion into one of 81 discrete actions.

        Each of the four axes (z, x, y translation; yaw rotation) is
        trinarized against its threshold (0 = negative, 1 = within
        threshold, 2 = positive), then combined via base-3 encoding.
        """
        trans_thresh = self.action_trans_thresh
        rot_thresh = self.action_rot_thresh

        # Forward/backward (z-axis)
        if translation[2] > trans_thresh:
            fb = 2  # forward
        elif translation[2] < -trans_thresh:
            fb = 0  # backward
        else:
            fb = 1  # stationary

        # Left/right (x-axis)
        if translation[0] > trans_thresh:
            lr = 2  # right
        elif translation[0] < -trans_thresh:
            lr = 0  # left
        else:
            lr = 1  # stationary

        # Up/down (y-axis)
        if translation[1] > trans_thresh:
            ud = 2  # up
        elif translation[1] < -trans_thresh:
            ud = 0  # down
        else:
            ud = 1  # stationary

        # Rotation, reduced to yaw about the vertical axis
        yaw = torch.atan2(rotation[0, 2], rotation[2, 2])
        if yaw > rot_thresh:
            rot = 2  # rotate right
        elif yaw < -rot_thresh:
            rot = 0  # rotate left
        else:
            rot = 1  # no rotation

        # Combine into a single index (base-3 encoding)
        action_idx = fb * 27 + lr * 9 + ud * 3 + rot
        return action_idx

    def _load_image(self, image_path: str) -> torch.Tensor:
        """Load a conditioning image for I2V as a [C, H, W] float tensor in [0, 1]."""
        if Image is None:
            raise ImportError("PIL is required for image loading")

        img = Image.open(image_path).convert("RGB")
        # PIL.resize takes (width, height); resolution is (height, width).
        img = img.resize((self.resolution[1], self.resolution[0]))
        img = torch.from_numpy(np.array(img)).permute(2, 0, 1).float() / 255.0

        return img

    def _prepare_i2v_mask(self, num_frames: int) -> torch.Tensor:
        """
        Prepare I2V mask for conditional generation.

        For I2V task, the first frame is conditioned (mask=0),
        and remaining frames are generated (mask=1).
        """
        mask = torch.ones(num_frames)
        if self.task == "i2v":
            mask[0] = 0  # First frame is conditioned
        return mask

    def _select_memory_window(
        self,
        video: torch.Tensor,
        w2c: torch.Tensor,
        intrinsic: torch.Tensor,
        action: torch.Tensor,
        i2v_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Select a random memory window for training.

        This simulates the AR generation process where we only
        attend to a window of previous frames.

        NOTE(review): a window with start_idx > 0 slices away frame 0, so
        the i2v_mask's conditioning zero is dropped for that sample —
        presumably intended (later chunks are fully generated); confirm.
        """
        T = video.shape[0]
        window_size = self.memory_window_size * self.chunk_latent_num

        if T <= window_size:
            # Sequence already fits in one window; nothing to crop.
            return video, w2c, intrinsic, action, i2v_mask

        # Random start position for the crop
        start_idx = random.randint(0, T - window_size)
        end_idx = start_idx + window_size

        return (
            video[start_idx:end_idx],
            w2c[start_idx:end_idx],
            intrinsic[start_idx:end_idx],
            action[start_idx:end_idx],
            i2v_mask[start_idx:end_idx],
        )


def collate_fn(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """Collate dataset samples into batched tensors for a DataLoader.

    Tensor fields are stacked along a new leading batch dimension;
    captions are gathered into a list of strings.

    Args:
        batch: Non-empty list of sample dicts from WorldPlayARDataset.

    Returns:
        Dict with stacked tensors, optional ``image_cond``, and ``caption``.
    """
    output = {}

    # Stack the tensor fields present in the samples
    for key in ["video", "w2c", "intrinsic", "action", "i2v_mask"]:
        if key in batch[0]:
            output[key] = torch.stack([item[key] for item in batch])

    # Bug fix: previously only batch[0] was checked, so a mixed batch
    # (some samples lacking an image condition) raised KeyError here.
    # Only stack when every sample provides image_cond.
    if all("image_cond" in item for item in batch):
        output["image_cond"] = torch.stack([item["image_cond"] for item in batch])

    # Collect captions as a plain list of strings
    output["caption"] = [item["caption"] for item in batch]

    return output