"""
Usage:
pip install opencv-python-headless
python3 srt_example_llava.py
"""

import argparse
import csv
import os
import time

import sglang as sgl


@sgl.function
def video_qa(s, num_frames, video_path, question):
    """sglang program: ask *question* about a video sampled at *num_frames* frames.

    The generated text is stored in the program state under the key "answer".
    """
    # User turn: the sampled video frames followed by the text question.
    s += sgl.user(sgl.video(video_path, num_frames) + question)
    # Assistant turn: generate the model's reply into the "answer" slot.
    s += sgl.assistant(sgl.gen("answer"))


def single(path, num_frames=16):
    """Run one greedy video-QA request on *path* and print the model's answer."""
    prompt = "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes"
    result = video_qa.run(
        video_path=path,
        num_frames=num_frames,
        question=prompt,
        temperature=0.0,
        max_new_tokens=1024,
    )
    print(result["answer"], "\n")


def split_into_chunks(lst, num_chunks):
    """Split *lst* into exactly *num_chunks* contiguous chunks.

    Items are distributed as evenly as possible: when ``len(lst)`` is not
    divisible by ``num_chunks``, the leftover items go to the earliest chunks,
    so chunk sizes differ by at most one. Order is preserved, every item is
    included, and an empty input yields ``num_chunks`` empty lists.

    Raises:
        ZeroDivisionError: if ``num_chunks`` is 0 (matches the old behavior).
    """
    # Bug fixes vs. the previous version:
    #   1. An empty list made chunk_size 0, and range(0, 0, 0) raised
    #      ValueError ("range() arg 3 must not be zero").
    #   2. With a remainder, it produced MORE than num_chunks chunks; callers
    #      indexing [cur_chunk] for cur_chunk < num_chunks silently dropped
    #      the overflow chunk's items.
    base, extra = divmod(len(lst), num_chunks)
    chunks = []
    start = 0
    for i in range(num_chunks):
        # The first `extra` chunks absorb one leftover item each.
        size = base + (1 if i < extra else 0)
        chunks.append(lst[start : start + size])
        start += size
    return chunks


def save_batch_results(batch_video_files, states, cur_chunk, batch_idx, save_dir):
    """Write one batch's (video_name, answer) pairs to a per-batch CSV file.

    The file is named ``chunk_<cur_chunk>_batch_<batch_idx>.csv`` inside
    *save_dir*; only the video's basename is recorded, not its full path.
    """
    out_path = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
    rows = [
        [os.path.basename(path), state["answer"]]
        for path, state in zip(batch_video_files, states)
    ]
    with open(out_path, "w", newline="") as fh:
        writer = csv.writer(fh)
        writer.writerow(["video_name", "answer"])
        writer.writerows(rows)

def compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir):
    """Merge all per-batch CSVs for *cur_chunk* into one final CSV.

    Each batch file's header row is dropped, its data rows are appended to
    ``final_results_chunk_<cur_chunk>.csv``, and the batch file is deleted
    once copied.
    """
    final_path = f"{save_dir}/final_results_chunk_{cur_chunk}.csv"
    with open(final_path, "w", newline="") as final_file:
        out = csv.writer(final_file)
        out.writerow(["video_name", "answer"])
        for idx in range(num_batches):
            batch_path = f"{save_dir}/chunk_{cur_chunk}_batch_{idx}.csv"
            with open(batch_path, "r") as batch_file:
                rows = list(csv.reader(batch_file))
            out.writerows(rows[1:])  # skip the per-batch header row
            os.remove(batch_path)

def find_video_files(video_dir):
    """Return a list of video file paths found at *video_dir*.

    If *video_dir* is itself a file, it is returned as a one-element list.
    Otherwise the directory tree is walked recursively and every file ending
    in .mp4, .avi, or .mov is collected.
    """
    # A single file path is its own one-item result.
    if os.path.isfile(video_dir):
        return [video_dir]

    return [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(video_dir)
        for name in names
        if name.endswith((".mp4", ".avi", ".mov"))
    ]

def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=64):
    """Caption every video in this worker's chunk and write the results to CSV.

    All videos under *video_dir* are split into *num_chunks* chunks; this call
    handles chunk *cur_chunk*, running the model in batches of *batch_size*,
    writing one CSV per batch, then merging them into a single final CSV.
    """
    video_files = find_video_files(video_dir)
    chunked_video_files = split_into_chunks(video_files, num_chunks)[cur_chunk]

    # Fixed: the emptiness check used to live INSIDE the loop, where a slice
    # produced by the range-stepped loop can never be empty — so the
    # "no videos" message was unreachable. Test up front instead.
    if not chunked_video_files:
        print("No video files found in the specified directory.")
        return

    question = "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes."
    num_batches = 0
    for i in range(0, len(chunked_video_files), batch_size):
        batch_video_files = chunked_video_files[i : i + batch_size]
        print(f"Processing batch of {len(batch_video_files)} video(s)...")

        batch_input = [
            {
                "num_frames": num_frames,
                "video_path": video_path,
                "question": question,
            }
            for video_path in batch_video_files
        ]

        start_time = time.time()
        states = video_qa.run_batch(batch_input, max_new_tokens=512, temperature=0.2)
        total_time = time.time() - start_time
        average_time = total_time / len(batch_video_files)
        print(
            f"Number of videos in batch: {len(batch_video_files)}. Average processing time per video: {average_time:.2f} seconds. Total time for this batch: {total_time:.2f} seconds"
        )

        save_batch_results(batch_video_files, states, cur_chunk, num_batches, save_dir)
        num_batches += 1

    compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir)


if __name__ == "__main__":
    # Command-line entry point: launch an sglang runtime for a
    # LLaVA-NeXT-Video model and caption the first discovered video.
    parser = argparse.ArgumentParser(
        description="Run video processing with specified port."
    )
    parser.add_argument(
        "--port",
        type=int,
        default=30000,
        help="The master port for distributed serving.",
    )
    parser.add_argument(
        "--chunk-idx", type=int, default=0, help="The index of the chunk to process."
    )
    parser.add_argument(
        "--num-chunks", type=int, default=8, help="The number of chunks to process."
    )
    parser.add_argument(
        "--save-dir",
        type=str,
        default="./work_dirs/llava_video",
        help="The directory to save the processed video files.",
    )
    parser.add_argument(
        "--video-dir",
        type=str,
        default="./videos/Q98Z4OTh8RwmDonc.mp4",
        help="The directory or path for the processed video files.",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="lmms-lab/LLaVA-NeXT-Video-7B",
        help="The model path for the video processing.",
    )
    parser.add_argument(
        "--num-frames",
        type=int,
        default=16,
        help="The number of frames to process in each video.",
    )
    parser.add_argument("--mm_spatial_pool_stride", type=int, default=2)

    args = parser.parse_args()

    cur_port = args.port
    cur_chunk = args.chunk_idx
    num_chunks = args.num_chunks
    num_frames = args.num_frames

    # Pick the tokenizer matching the model size; anything else is rejected.
    if "34b" in args.model_path.lower():
        tokenizer_path = "liuhaotian/llava-v1.6-34b-tokenizer"
    elif "7b" in args.model_path.lower():
        tokenizer_path = "llava-hf/llava-1.5-7b-hf"
    else:
        print("Invalid model path. Please specify a valid model path.")
        exit()

    # NOTE: "overide" (sic) matches the sglang Runtime keyword argument.
    model_overide_args = {
        "mm_spatial_pool_stride": args.mm_spatial_pool_stride,
        "architectures": ["LlavaVidForCausalLM"],
        "num_frames": args.num_frames,
        "model_type": "llava",
    }
    if "34b" in args.model_path.lower():
        model_overide_args["image_token_index"] = 64002

    # 32 frames require a doubled context window via linear RoPE scaling;
    # more than 32 frames is unsupported.
    if args.num_frames == 32:
        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
        model_overide_args["max_sequence_length"] = 4096 * 2
        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
    elif args.num_frames > 32:
        print(
            "The maximum number of frames to process is 32. Please specify a valid number of frames."
        )
        exit()

    runtime = sgl.Runtime(
        model_path=args.model_path,  # e.g. "liuhaotian/llava-v1.6-vicuna-7b"
        tokenizer_path=tokenizer_path,
        port=cur_port,
        additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
        model_overide_args=model_overide_args,
        tp_size=1,
    )
    sgl.set_default_backend(runtime)
    print(f"chat template: {runtime.endpoint.chat_template.name}")

    try:
        # Run a single request on the first discovered video.
        print("\n========== single ==========\n")
        root = args.video_dir
        if os.path.isfile(root):
            video_files = [root]
        else:
            video_files = [
                os.path.join(root, f)
                for f in os.listdir(root)
                if f.endswith((".mp4", ".avi", ".mov"))
            ]  # Add more extensions if needed

        # Only the first video is processed in single mode.
        to_process = video_files[:1]
        start_time = time.time()
        for cur_video in to_process:
            print(cur_video)
            single(cur_video, num_frames)
        total_time = time.time() - start_time

        # Fixed: the original divided by len(video_files) even though only
        # video_files[:1] was processed, and raised ZeroDivisionError when
        # no videos matched.
        if to_process:
            average_time = total_time / len(to_process)
            print(f"Average processing time per video: {average_time:.2f} seconds")
        else:
            print("No video files found to process.")
    finally:
        # Shut the runtime down exactly once (the original called
        # runtime.shutdown() twice, a leftover of a commented-out except).
        runtime.shutdown()

    # To run a batch of requests instead:
    # print("\n========== batch ==========\n")
    # if not os.path.exists(args.save_dir):
    #     os.makedirs(args.save_dir)
    # batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, num_chunks)