vid_recon.py 4.15 KB
Newer Older
gushiqiao's avatar
gushiqiao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import argparse

import cv2
import torch
from loguru import logger

from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
from lightx2v.models.video_encoders.hf.wan.vae_2_2 import Wan2_2_VAE
from lightx2v.models.video_encoders.hf.wan.vae_tiny import Wan2_2_VAE_tiny, WanVAE_tiny


class VideoTensorReader:
    """Iterate over the frames of a video file as RGB, channel-first tensors.

    The file is opened with ``cv2.VideoCapture`` when the reader is created,
    and the capture is released as soon as a read yields no frame.

    Attributes:
        cap: The underlying ``cv2.VideoCapture`` handle.
        fps: Frame rate reported by the capture via ``cv2.CAP_PROP_FPS``.
    """

    def __init__(self, video_file_path):
        """Open ``video_file_path`` for reading.

        Raises:
            AssertionError: If the capture reports that it is not opened.
        """
        capture = cv2.VideoCapture(video_file_path)
        self.cap = capture
        assert capture.isOpened(), f"Could not load {video_file_path}"
        self.fps = capture.get(cv2.CAP_PROP_FPS)

    def __iter__(self):
        """Return the reader itself; it acts as its own iterator."""
        return self

    def __next__(self):
        """Return the next frame converted from BGR HWC to an RGB CHW tensor.

        Raises:
            StopIteration: When the capture returns no frame (end of video or
                a read error). The capture is released before raising.
        """
        got_frame, bgr_hwc = self.cap.read()
        if got_frame:
            rgb_hwc = cv2.cvtColor(bgr_hwc, cv2.COLOR_BGR2RGB)
            return torch.from_numpy(rgb_hwc).permute(2, 0, 1)  # HWC -> CHW
        # Nothing more to read: free the capture, then end the iteration.
        self.cap.release()
        raise StopIteration


class VideoTensorWriter:
    """Write RGB, channel-first frame tensors to a video file using the ``mp4v`` codec.

    Attributes:
        writer: The underlying ``cv2.VideoWriter`` handle.
    """

    def __init__(self, video_file_path, width_height, fps=30):
        """Create the output file.

        Args:
            video_file_path: Destination path of the video.
            width_height: Frame size passed to ``cv2.VideoWriter``.
            fps: Frame rate of the output video.

        Raises:
            AssertionError: If the writer reports that it is not opened.
        """
        codec = cv2.VideoWriter_fourcc(*"mp4v")
        self.writer = cv2.VideoWriter(video_file_path, codec, fps, width_height)
        assert self.writer.isOpened(), f"Could not create writer for {video_file_path}"

    def write(self, frame_tensor):
        """Append one frame, given as a 3-D tensor whose first dimension has size 3.

        Raises:
            AssertionError: If ``frame_tensor`` is not 3-D with 3 leading channels.
        """
        assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??"
        emit = self.writer.write
        rgb_hwc = frame_tensor.permute(1, 2, 0).numpy()  # CHW -> HWC
        bgr_hwc = cv2.cvtColor(rgb_hwc, cv2.COLOR_RGB2BGR)
        emit(bgr_hwc)

    def __del__(self):
        """Release the writer if one was ever attached to this instance."""
        try:
            handle = self.writer
        except AttributeError:
            # No writer attribute exists (e.g. construction did not get that far).
            return
        handle.release()


def _build_parser():
    """Create the command-line parser for the video reconstruction script."""
    parser = argparse.ArgumentParser(description="Encode and decode videos using the TaeHV model for reconstruction")
    parser.add_argument("video_paths", nargs="+", help="Paths to input video files (multiple allowed)")
    parser.add_argument("--checkpoint", "-c", help="Path to the model checkpoint file")
    # Default is None (not "cuda") so that omitting --device really triggers
    # the auto-detection promised by the help text.
    parser.add_argument("--device", "-d", default=None, help='Computing device (e.g., "cuda", "mps", "cpu"; default: auto-detect available device)')
    parser.add_argument("--dtype", default="bfloat16", choices=["bfloat16", "float32"], help="Data type for model computation (default: bfloat16)")
    parser.add_argument(
        "--model_type",
        choices=["taew2_1", "taew2_2", "vaew2_1", "vaew2_2"],
        required=True,
        help="Type of the model to use (choices: taew2_1, taew2_2, vaew2_1, vaew2_2)",
    )
    parser.add_argument("--use_lightvae", default=False, action="store_true")
    return parser


def _resolve_device(name):
    """Return the requested ``torch.device``, or the best available one if ``name`` is empty.

    Auto-detection prefers CUDA, then MPS, then falls back to CPU.
    """
    if name:
        return torch.device(name)
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _load_model(model_type, checkpoint, dtype, dev, use_lightvae):
    """Instantiate the VAE selected by ``model_type`` exactly once.

    Args:
        model_type: One of ``taew2_1``, ``taew2_2``, ``vaew2_1``, ``vaew2_2``.
        checkpoint: Checkpoint path, forwarded as ``vae_pth``.
        dtype: Torch dtype forwarded to the model constructor.
        dev: Torch device forwarded to the model constructor.
        use_lightvae: Forwarded only to the ``vaew2_1`` constructor.

    Returns:
        The constructed model; ``tae*`` models are additionally moved with ``.to(dev)``.
    """
    model_map = {"taew2_1": WanVAE_tiny, "taew2_2": Wan2_2_VAE_tiny, "vaew2_1": WanVAE, "vaew2_2": Wan2_2_VAE}
    model_args = {"vae_pth": checkpoint, "dtype": dtype, "device": dev}
    # Equality, not substring membership: `x in "vaew2_1"` would also accept
    # any substring of that name.
    if model_type == "vaew2_1":
        model_args["use_lightvae"] = use_lightvae

    model = model_map[model_type](**model_args)
    if model_type.startswith("tae"):
        # Reuse the instance built above instead of constructing (and loading
        # the checkpoint) a second time.
        model = model.to(dev)
    return model


def _reconstruct_video(model, video_path, idx, dev, dtype):
    """Encode and decode one video and save the result next to the input.

    The output is written to ``{video_path}.reconstructed_{idx}.mp4`` at the
    input's frame rate rounded to an integer.

    Raises:
        RuntimeError: If no frame could be read from ``video_path``.
    """
    logger.info(f"Processing video {video_path}...")
    video_in = VideoTensorReader(video_path)
    frames = list(video_in)
    if not frames:
        # torch.stack on an empty list fails with an opaque message; be explicit.
        raise RuntimeError(f"No frames could be read from {video_path}")
    video = torch.stack(frames, 0)[None]  # Add batch dimension

    # Pure inference: no autograd graph is needed for reconstruction.
    with torch.no_grad():
        vid_dev = video.to(dev, dtype).div_(255.0)  # Normalize to [0,1]
        vid_enc = model.encode_video(vid_dev)
        if isinstance(vid_enc, tuple):
            vid_enc = vid_enc[0]
        vid_dec = model.decode_video(vid_enc)
        # Back to 8-bit values on the CPU; index 0 drops the batch dimension.
        frames_out = vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]

    video_out_path = f"{video_path}.reconstructed_{idx}.mp4"
    # The writer expects (width, height); these are taken from the last two dims.
    frame_size = (vid_dec.shape[-1], vid_dec.shape[-2])
    fps = int(round(video_in.fps))
    video_out = VideoTensorWriter(video_out_path, frame_size, fps)
    try:
        for frame in frames_out:
            video_out.write(frame)
    finally:
        # Finalize the file now rather than whenever the writer object is
        # garbage-collected, so the "saved" message below is accurate.
        video_out.writer.release()
    logger.info(f"  Reconstructed video saved to {video_out_path}")


def main(argv=None):
    """Run the reconstruction script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Exits with status 2 (via ``parser.error``) when ``--use_lightvae`` is
    combined with a model type other than ``vaew2_1``.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)
    if args.use_lightvae and args.model_type != "vaew2_1":
        # A real CLI error instead of an assert, which is stripped under -O.
        parser.error("--use_lightvae is only supported with --model_type vaew2_1")

    dev = _resolve_device(args.device)
    dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32}
    dtype = dtype_map[args.dtype]

    model = _load_model(args.model_type, args.checkpoint, dtype, dev, args.use_lightvae)

    # Process each input video
    for idx, video_path in enumerate(args.video_paths):
        _reconstruct_video(model, video_path, idx, dev, dtype)


if __name__ == "__main__":
    main()