"""
Usage:
pip install opencv-python-headless
python3 srt_example_llava.py
"""

import argparse
import csv
import os
import time
import urllib.request

import sglang as sgl


@sgl.function
def video_qa(s, num_frames, video_path, question):
    # Build the prompt: one user turn containing `num_frames` frames sampled
    # from the video at `video_path`, followed by the text question.
    s += sgl.user(sgl.video(video_path, num_frames) + question)
    # Generate the model's reply; the result is exposed as state["answer"].
    s += sgl.assistant(sgl.gen("answer"))


def single(path, num_frames=16):
    """Describe one video with `video_qa` and print the generated answer."""
    question = "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes"
    result = video_qa.run(
        num_frames=num_frames,
        video_path=path,
        question=question,
        temperature=0.0,
        max_new_tokens=1024,
    )
    print(result["answer"], "\n")


def split_into_chunks(lst, num_chunks):
    """Split `lst` into exactly `num_chunks` contiguous chunks.

    Items are distributed as evenly as possible: the first
    ``len(lst) % num_chunks`` chunks receive one extra item.  Chunks may be
    empty when there are fewer items than chunks.

    Fixes two defects of the previous fixed-stride slicing:
    - it could produce MORE than `num_chunks` chunks when the length was not
      evenly divisible, so trailing items were unreachable by any worker
      indexing the result with ``cur_chunk < num_chunks``;
    - it crashed on an empty list (``range(..., step=0)`` raises ValueError).
    """
    base, remainder = divmod(len(lst), num_chunks)
    chunks = []
    start = 0
    for i in range(num_chunks):
        # The first `remainder` chunks absorb the leftover items.
        size = base + (1 if i < remainder else 0)
        chunks.append(lst[start : start + size])
        start += size
    return chunks


def save_batch_results(batch_video_files, states, cur_chunk, batch_idx, save_dir):
    """Write one batch's (video_name, answer) pairs to a per-batch CSV file."""
    out_path = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
    # Pair every video with its generated answer, keeping only the basename.
    rows = [
        [os.path.basename(video_path), result["answer"]]
        for video_path, result in zip(batch_video_files, states)
    ]
    with open(out_path, "w", newline="") as out_file:
        csv_out = csv.writer(out_file)
        csv_out.writerow(["video_name", "answer"])
        csv_out.writerows(rows)


def compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir):
    """Merge all per-batch CSVs of one chunk into a single file.

    The per-batch files are deleted after their rows are copied.
    """
    merged_path = f"{save_dir}/final_results_chunk_{cur_chunk}.csv"
    with open(merged_path, "w", newline="") as merged_file:
        merged_writer = csv.writer(merged_file)
        merged_writer.writerow(["video_name", "answer"])
        for batch_idx in range(num_batches):
            part_path = f"{save_dir}/chunk_{cur_chunk}_batch_{batch_idx}.csv"
            with open(part_path, "r") as part_file:
                part_rows = csv.reader(part_file)
                next(part_rows)  # drop the per-batch header row
                merged_writer.writerows(part_rows)
            os.remove(part_path)


def find_video_files(video_dir):
    """Return the list of video files under `video_dir`.

    If `video_dir` is itself a file, it is returned as a one-element list
    (regardless of extension, matching the original behavior).
    """
    if os.path.isfile(video_dir):
        return [video_dir]

    # Recursively collect files with a recognized video extension.
    video_exts = (".mp4", ".avi", ".mov")
    return [
        os.path.join(dirpath, fname)
        for dirpath, _dirnames, filenames in os.walk(video_dir)
        for fname in filenames
        if fname.endswith(video_exts)
    ]


def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size=64):
    """Describe every video in this worker's chunk and save answers to CSV.

    All videos under `video_dir` are split into `num_chunks` chunks; only
    chunk `cur_chunk` is processed here, in groups of `batch_size`.  Each
    group is written to its own CSV, and the per-batch files are merged
    (and removed) at the end.
    """
    chunk_videos = split_into_chunks(find_video_files(video_dir), num_chunks)[cur_chunk]
    num_batches = 0
    question = "Please provide a detailed description of the video, focusing on the main subjects, their actions, the background scenes."

    for offset in range(0, len(chunk_videos), batch_size):
        group = chunk_videos[offset : offset + batch_size]
        print(f"Processing batch of {len(group)} video(s)...")

        if not group:
            print("No video files found in the specified directory.")
            return

        batch_input = [
            {
                "num_frames": num_frames,
                "video_path": video_path,
                "question": question,
            }
            for video_path in group
        ]

        t0 = time.time()
        states = video_qa.run_batch(batch_input, max_new_tokens=512, temperature=0.2)
        elapsed = time.time() - t0
        per_video = elapsed / len(group)
        print(
            f"Number of videos in batch: {len(group)}. Average processing time per video: {per_video:.2f} seconds. Total time for this batch: {elapsed:.2f} seconds"
        )

        save_batch_results(group, states, cur_chunk, num_batches, save_dir)
        num_batches += 1

    compile_and_cleanup_final_results(cur_chunk, num_batches, save_dir)


if __name__ == "__main__":
    # ---- Download a small demo video into ~/.cache ----------------------
    url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
    cache_dir = os.path.expanduser("~/.cache")
    file_path = os.path.join(cache_dir, "jobs.mp4")
    os.makedirs(cache_dir, exist_ok=True)

    # BUG FIX: the original called requests.get() without ever importing
    # `requests`, crashing with a NameError.  The stdlib urllib is used
    # instead (urlopen raises HTTPError/URLError on bad responses, matching
    # the original raise_for_status intent).
    with urllib.request.urlopen(url) as response:
        video_bytes = response.read()
    with open(file_path, "wb") as f:
        f.write(video_bytes)
    print(f"File downloaded and saved to: {file_path}")

    # ---- Command-line arguments -----------------------------------------
    parser = argparse.ArgumentParser(
        description="Run video processing with specified port."
    )
    parser.add_argument(
        "--port",
        type=int,
        default=30000,
        help="The master port for distributed serving.",
    )
    parser.add_argument(
        "--chunk-idx", type=int, default=0, help="The index of the chunk to process."
    )
    parser.add_argument(
        "--num-chunks", type=int, default=8, help="The number of chunks to process."
    )
    parser.add_argument(
        "--save-dir",
        type=str,
        default="./work_dirs/llava_video",
        help="The directory to save the processed video files.",
    )
    parser.add_argument(
        "--video-dir",
        type=str,
        default=os.path.expanduser("~/.cache/jobs.mp4"),
        help="The directory or path for the processed video files.",
    )
    parser.add_argument(
        "--model-path",
        type=str,
        default="lmms-lab/LLaVA-NeXT-Video-7B",
        help="The model path for the video processing.",
    )
    parser.add_argument(
        "--num-frames",
        type=int,
        default=16,
        help="The number of frames to process in each video.",
    )
    parser.add_argument("--mm_spatial_pool_stride", type=int, default=2)
    args = parser.parse_args()

    cur_port = args.port
    cur_chunk = args.chunk_idx
    num_chunks = args.num_chunks
    num_frames = args.num_frames

    # ---- Pick a tokenizer matching the model size -----------------------
    if "34b" in args.model_path.lower():
        tokenizer_path = "liuhaotian/llava-v1.6-34b-tokenizer"
    elif "7b" in args.model_path.lower():
        tokenizer_path = "llava-hf/llava-1.5-7b-hf"
    else:
        print("Invalid model path. Please specify a valid model path.")
        exit()

    # ---- Model-config overrides for video LLaVA -------------------------
    model_overide_args = {
        "mm_spatial_pool_stride": args.mm_spatial_pool_stride,
        "architectures": ["LlavaVidForCausalLM"],
        "num_frames": args.num_frames,
        "model_type": "llava",
    }
    if "34b" in args.model_path.lower():
        model_overide_args["image_token_index"] = 64002

    if args.num_frames == 32:
        # 32 frames need a doubled context window via linear RoPE scaling.
        model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
        model_overide_args["max_sequence_length"] = 4096 * 2
        model_overide_args["tokenizer_model_max_length"] = 4096 * 2
    elif args.num_frames > 32:
        print(
            "The maximum number of frames to process is 32. Please specify a valid number of frames."
        )
        exit()

    # ---- Launch the sglang runtime --------------------------------------
    runtime = sgl.Runtime(
        model_path=args.model_path,  # e.g. "liuhaotian/llava-v1.6-vicuna-7b"
        tokenizer_path=tokenizer_path,
        port=cur_port,
        additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4],
        model_overide_args=model_overide_args,
        tp_size=1,
    )
    sgl.set_default_backend(runtime)
    print(f"chat template: {runtime.endpoint.chat_template.name}")

    try:
        # ---- Run a single request ---------------------------------------
        print("\n========== single ==========\n")
        root = args.video_dir
        if os.path.isfile(root):
            video_files = [root]
        else:
            video_files = [
                os.path.join(root, f)
                for f in os.listdir(root)
                if f.endswith((".mp4", ".avi", ".mov"))
            ]  # Add more extensions if needed

        processed = video_files[:1]  # demo: only the first video
        start_time = time.time()
        for cur_video in processed:
            print(cur_video)
            single(cur_video, num_frames)
        total_time = time.time() - start_time
        # BUG FIX: the original divided by len(video_files) even though only
        # the first video is processed, and raised ZeroDivisionError when the
        # directory contained no videos at all.
        if processed:
            average_time = total_time / len(processed)
            print(f"Average processing time per video: {average_time:.2f} seconds")
        else:
            print("No video files found to process.")
    finally:
        # BUG FIX: shutdown was previously called twice back-to-back; a
        # single call in `finally` guarantees cleanup even on errors.
        runtime.shutdown()

    # To run a batch of requests instead:
    # print("\n========== batch ==========\n")
    # os.makedirs(args.save_dir, exist_ok=True)
    # batch(args.video_dir, args.save_dir, cur_chunk, num_chunks, num_frames, num_chunks)