vis_proto.py 4.39 KB
Newer Older
zzg_666's avatar
zzg_666 committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# -*- coding: utf-8 -*-
# Copyright 2025 BAAI. and/or its affiliates.
# SPDX-License-Identifier: Apache-2.0

import os
import sys
import argparse
import numpy as np
from PIL import Image
import io

# Walk three dirname() steps up from this file's absolute path to get the
# project root, and put it at the front of sys.path (if it is not already
# there) so that the `src.*` imports that follow can be resolved.
project_root = os.path.abspath(__file__)
for _ in range(3):
    project_root = os.path.dirname(project_root)
if sys.path.count(project_root) == 0:
    sys.path.insert(0, project_root)
from src.proto import emu_pb as story_pb
from src.utils.video_utils import plot_string, save_image_list_to_video

def _has_text(text):
    """Return True when *text* is non-empty and not only whitespace."""
    return bool(text and text.strip())


def _write_text(path, text):
    """Write *text* followed by a newline to *path*.

    The file is always encoded as UTF-8 so that story text containing
    non-ASCII characters is written identically regardless of the platform's
    default locale encoding.
    """
    with open(path, 'w', encoding='utf-8') as f:
        print(text, file=f)


def _write_bytes(path, data):
    """Write the raw bytes *data* to *path*."""
    with open(path, 'wb') as f:
        f.write(data)


def _iter_image_bytes(story):
    """Yield the encoded bytes of every image in *story*.

    Reference images come first, followed by the segment images in
    clip/segment order.
    """
    for ref in story.reference_images:
        yield ref.image.image_data
    for clip in story.clips:
        for segment in clip.segments:
            for im in segment.images:
                yield im.image.image_data


def _decode_frame(image_bytes, target_size):
    """Decode *image_bytes* to an RGB numpy array resized to *target_size*."""
    img = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    if img.size != target_size:
        img = img.resize(target_size, Image.Resampling.LANCZOS)
    return np.array(img)


def _export_story_files(story, output_path):
    """Dump the question, summary, reference images and segments of *story*.

    File names follow the ``NNN_*`` scheme: index 000 is used for the question
    and the global chain of thought, reference images are numbered from 000,
    and segments continue from the number of reference images (or from 001
    when there are none).
    """
    _write_text(f"{output_path}/000_question.txt", story.question)

    if _has_text(story.summary):
        _write_text(f"{output_path}/000_global_cot.txt", story.summary)

    for i, ref in enumerate(story.reference_images):
        _write_bytes(f"{output_path}/{i:03d}_reference_image.png", ref.image.image_data)

    # Segment numbering starts after the reference images; without any
    # reference image it starts at 1.
    idx = len(story.reference_images) or 1

    for clip in story.clips:
        for segment in clip.segments:
            _write_text(f"{output_path}/{idx:03d}_text.txt", segment.asr)
            for im_idx, im in enumerate(segment.images):
                stem = f"{output_path}/{idx:03d}_{im_idx:02d}_image"
                _write_bytes(f"{stem}.png", im.image.image_data)
                if _has_text(im.chain_of_thought):
                    _write_text(f"{stem}_cot.txt", im.chain_of_thought)
            idx += 1


def _build_video_frames(story):
    """Return the list of video frames for *story*.

    Frame order: the rendered question (if any), the reference images, then
    for each segment its rendered ASR text (if any) followed by its images.
    All image frames are resized to the size of the first image found in the
    story; (512, 512) is used when the story contains no image at all.
    """
    # Only the first image is needed to fix the frame size. Image.open() is
    # enough to read the size, so nothing is decoded twice.
    target_size = (512, 512)
    first_image = next(_iter_image_bytes(story), None)
    if first_image is not None:
        with Image.open(io.BytesIO(first_image)) as probe:
            target_size = probe.size
    text_size = (target_size[0], target_size[1])

    frames = []

    if _has_text(story.question):
        frames.append(plot_string(story.question, image_size=text_size))

    for ref in story.reference_images:
        frames.append(_decode_frame(ref.image.image_data, target_size))

    for clip in story.clips:
        for segment in clip.segments:
            if _has_text(segment.asr):
                frames.append(plot_string(segment.asr, image_size=text_size))
            for im in segment.images:
                frames.append(_decode_frame(im.image.image_data, target_size))

    return frames


def main():
    """Command-line entry point: dump a serialized Story protobuf to disk.

    Reads the protobuf file given by ``--input``, writes its question,
    summary, reference images, segment texts, segment images and per-image
    chain-of-thought texts into the ``--output`` directory (created if
    missing). With ``--video`` it additionally writes ``video.mp4`` into the
    same directory at ``--fps`` frames per second.
    """
    parser = argparse.ArgumentParser(description='Visualize protobuf story files')
    parser.add_argument('--input', '-i', required=True, help='Input protobuf file path')
    parser.add_argument('--output', '-o', required=True, help='Output directory path')
    parser.add_argument('--video', action='store_true', help='Generate video from protobuf content')
    parser.add_argument('--fps', type=int, default=1, help='Frames per second for video (default: 1)')
    args = parser.parse_args()

    input_path = args.input
    output_path = args.output

    os.makedirs(output_path, exist_ok=True)

    with open(input_path, 'rb') as f:
        story = story_pb.Story()
        story.ParseFromString(f.read())

    _export_story_files(story, output_path)

    if args.video:
        video_images = _build_video_frames(story)
        # Skip writing a video when the story produced no frame at all.
        if video_images:
            video_path = f"{output_path}/video.mp4"
            save_image_list_to_video(video_images, video_path, fps=args.fps, quality='high')
            print(f"Video saved to: {video_path}")

# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()