"src/vscode:/vscode.git/clone" did not exist on "c180f6015477a15eb9526092dbca8d06b5600e35"
srt_example_llava_v.py 8.61 KB
Newer Older
Yuanhan Zhang's avatar
Yuanhan Zhang committed
1
2
3
4
"""
Usage: python3 srt_example_llava_v.py
       (optional flags: --port, --chunk-idx, --num-chunks, --save-dir,
        --video-dir, --model-path, --num-frames, --mm_spatial_pool_stride;
        see the argparse setup in __main__)
"""

import argparse
import csv
import os
import time

import sglang as sgl


@sgl.function
def video_qa(s, num_frames, video_path, question):
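    """SGLang program: attach `num_frames` frames of the video at `video_path`,
    ask `question`, and generate the reply into the "answer" variable."""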
    s += sgl.user(sgl.video(video_path, num_frames) + question)
    s += sgl.assistant(sgl.gen("answer"))


def single(path, num_frames=16):
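    """Run video_qa on a single video and print the generated answer."""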
    state = video_qa.run(
        num_frames=num_frames,
        video_path=path,
        question="Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes",
        temperature=0.0,
        max_new_tokens=1024,
    )
    print(state["answer"], "\n")


def split_into_chunks(lst, num_chunks):
    """Split a list into exactly `num_chunks` chunks (later chunks may be empty)."""
    # Use ceiling division so every item lands in one of the first `num_chunks` chunks;
    # floor division could create extra chunks that no worker would ever pick up.
    chunk_size = max(1, -(-len(lst) // num_chunks))
    chunks = [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
    # Pad with empty chunks so the result always has exactly num_chunks entries.
    chunks.extend([[] for _ in range(num_chunks - len(chunks))])
    return chunks


def save_batch_results(batch_video_files, states, cur_chunk, batch_idx, save_dir):
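    """Write (video_name, answer) rows for one batch to a per-chunk, per-batch CSV."""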
    csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
    with open(csv_filename, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(["video_name", "answer"])
        for video_path, state in zip(batch_video_files, states):
            video_name = os.path.basename(video_path)
            writer.writerow([video_name, state["answer"]])


def compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir):
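    """Merge this chunk's per-batch CSVs into one final CSV, then delete the batch files."""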
    final_csv_filename = f"{save_dir}/final_results_chunk_{cur_chunk}.csv"
    with open(final_csv_filename, "w", newline="") as final_csvfile:
        writer = csv.writer(final_csvfile)
        writer.writerow(["video_name", "answer"])
        for batch_idx in range(num_batches):
            batch_csv_filename = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
            with open(batch_csv_filename, "r") as batch_csvfile:
                reader = csv.reader(batch_csvfile)
                next(reader)  # Skip header row
                for row in reader:
                    writer.writerow(row)
            os.remove(batch_csv_filename)


def find_video_files(video_dir):
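    """Return a list of video files under `video_dir`, or `[video_dir]` if it is itself a file."""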
    # Check if the video_dir is actually a file
    if os.path.isfile(video_dir):
        # If it's a file, return it as a single-element list
        return [video_dir]

    # Original logic to find video files in a directory
    video_files = []
    for root, dirs, files in os.walk(video_dir):
        for file in files:
            if file.endswith((".mp4", ".avi", ".mov")):
                video_files.append(os.path.join(root, file))
    return video_files


def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=64):
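    """Run video_qa over this worker's chunk of videos in batches of `batch_size`
    and save the answers as CSV files in `save_dir`."""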
    video_files = find_video_files(video_dir)
    chunked_video_files = split_into_chunks(video_files, num_chunks)[cur_chunk]
    num_batches = 0

    for i in range(0, len(chunked_video_files), batch_size):
        batch_video_files = chunked_video_files[i : i + batch_size]
        print(f"Processing batch of {len(batch_video_files)} video(s)...")

        if not batch_video_files:
            print("No video files found in the specified directory.")
            return

        batch_input = [
            {
                "num_frames": num_frames,
                "video_path": video_path,
                "question": "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes.",
            }
            for video_path in batch_video_files
        ]

        start_time = time.time()
        states = video_qa.run_batch(batch_input, max_new_tokens=512, temperature=0.2)
        total_time = time.time() - start_time
        average_time = total_time / len(batch_video_files)
        print(
            f"Number of videos in batch: {len(batch_video_files)}. Average processing time per video: {average_time:.2f} seconds. Total time for this batch: {total_time:.2f} seconds"
        )

        save_batch_results(batch_video_files, states, cur_chunk, num_batches, save_dir)
        num_batches += 1

    compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir)


if __name__ == "__main__":

    # Create the parser
    parser = argparse.ArgumentParser(
        description="Run video processing with specified port."
    )

    # Add an argument for the port
    parser.add_argument(
        "--port",
        type=int,
        default=30000,
        help="The master port for distributed serving.",
    )
    parser.add_argument(
        "--chunk-idx", type=int, default=0, help="The index of the chunk to process."
    )
    parser.add_argument(
        "--num-chunks", type=int, default=8, help="The number of chunks to process."
    )
    parser.add_argument(
        "--save-dir",
        type=str,
        default="./work_dirs/llava_video",
        help="The directory to save the processed video files.",
    )
    parser.add_argument(
        "--video-dir",
        type=str,
        default="./videos/Q98Z4OTh8RwmDonc.mp4",
        help="The directory or path for the processed video files.",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="lmms-lab/LLaVA-NeXT-Video-7B",
        help="The model path for the video processing.",
    )
    parser.add_argument(
        "--num-frames",
        type=int,
        default=16,
        help="The number of frames to process in each video.",
    )
    parser.add_argument("--mm_spatial_pool_stride", type=int, default=2)

    # Parse the arguments
    args = parser.parse_args()

    cur_port = args.port

    cur_chunk = args.chunk_idx

    num_chunks = args.num_chunks

    num_frames = args.num_frames

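    # Pick a tokenizer that matches the base LLM of the selected LLaVA checkpoint.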
    if "34b" in args.model_path.lower():
        tokenizer_path = "liuhaotian/llava-v1.6-34b-tokenizer"
    elif "7b" in args.model_path.lower():
        tokenizer_path = "llava-hf/llava-1.5-7b-hf"
    else:
        print("Invalid model path. Please specify a valid model path.")
        exit()

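    # Config overrides passed to sgl.Runtime so the server loads the video variant of
    # LLaVA (LlavaVidForCausalLM) with the requested frame count and pooling stride.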
    model_overide_args = {}

    model_overide_args["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
    model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
    model_overide_args["num_frames"] = args.num_frames
    model_overide_args["model_type"] = "llava"

    if "34b" in args.model_path.lower():
        model_overide_args["image_token_index"] = 64002

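    # 32 frames need a longer context, so double the context window with linear RoPE scaling.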
    if args.num_frames == 32:
        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
        model_overide_args["max_sequence_length"] = 4096 * 2
        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
    elif args.num_frames < 32:
        pass
    else:
        print(
            "The maximum number of frames to process is 32. Please specify a valid number of frames."
        )
        exit()

    runtime = sgl.Runtime(
        model_path=args.model_path,  # "liuhaotian/llava-v1.6-vicuna-7b",
        tokenizer_path=tokenizer_path,
        port=cur_port,
        additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
        model_overide_args=model_overide_args,
        tp_size=1,
    )
    sgl.set_default_backend(runtime)
    print(f"chat template: {runtime.endpoint.chat_template.name}")

    # Run a single request
    # try:
    print("\n========== single ==========\n")
    root = args.video_dir
    if os.path.isfile(root):
        video_files = [root]
    else:
        video_files = [
            os.path.join(root, f)
            for f in os.listdir(root)
            if f.endswith((".mp4", ".avi", ".mov"))
        ]  # Add more extensions if needed
    start_time = time.time()  # Start time for processing a single video
    for cur_video in video_files[:1]:
        print(cur_video)
        single(cur_video, num_frames)
    end_time = time.time()  # End time for processing a single video
    total_time = end_time - start_time
    # Average over the videos actually processed (only the first one is run above).
    average_time = total_time / len(video_files[:1])
    print(f"Average processing time per video: {average_time:.2f} seconds")
    runtime.shutdown()
    # except Exception as e:
    #     print(e)
    runtime.shutdown()

    # # # Run a batch of requests
    # print("\n========== batch ==========\n")
    # if not os.path.exists(args.save_dir):
    #     os.makedirs(args.save_dir)
    # batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, num_chunks)
    # runtime.shutdown()