Unverified Commit 823b4429 authored by Mick's avatar Mick Committed by GitHub
Browse files

lang: support direct video inference (#9936)


Co-authored-by: default avatarLianmin Zheng <lianminzheng@gmail.com>
parent 14a4d80e
......@@ -21,7 +21,6 @@ from sglang.bench_serving import (
)
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
from sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageContentPart
from sglang.utils import encode_video_base64
# Type of a message's content field: either a plain prompt string, or a list of
# content parts that may mix text with image/video entries.
MsgContent = Union[str, List[ChatCompletionMessageContentPart]]
......@@ -325,15 +324,9 @@ def sample_nextqa_requests(
prompt_len = len(prompt_token_ids)
output_len = fixed_output_len # max output len, not real output len
# video input
base64_data = encode_video_base64(video.path, video.num_frames)
# NOTE: This will be replaced by the expanded length from the server
prompt_len += video.num_frames
# add to content
content = [
{"type": "image_url", "image_url": {"url": base64_data}},
{"type": "video_url", "video_url": {"url": video}},
{"type": "text", "text": prompt},
]
......
......@@ -33,7 +33,7 @@
"from sglang import assistant, function, gen, system, user\n",
"from sglang import image\n",
"from sglang import RuntimeEndpoint\n",
"from sglang.lang.api import set_default_backend\n",
"from sglang.lang.api import set_default_backend, video\n",
"from sglang.srt.utils import load_image\n",
"from sglang.test.doc_patch import launch_server_cmd\n",
"from sglang.utils import print_highlight, terminate_process, wait_for_server\n",
......@@ -421,7 +421,11 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {},
"outputs": [],
"source": [
"@function\n",
......@@ -436,6 +440,30 @@
"print_highlight(state[\"answer\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ask a question about a video"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@function\n",
"def video_qa(s, video_file, question):\n",
" s += user(video(video_file) + question)\n",
" s += assistant(gen(\"answer\", max_tokens=256))\n",
"\n",
"\n",
"video_url = \"https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/videos/jobs_presenting_ipod.mp4\"\n",
"state = video_qa(video_url, \"What is in the video?\")\n",
"print_highlight(state[\"answer\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
......
......@@ -229,7 +229,7 @@ def image(expr: SglExpr):
return SglImage(expr)
def video(path: str, num_frames: int):
def video(path: str, num_frames: int = -1):
    """Wrap a video reference in an ``SglVideo`` expression.

    Args:
        path: Local file path or URL of the video.
        num_frames: Frames to sample; defaults to -1, which leaves the
            frame count for the backend to decide — TODO(review): confirm
            the sentinel's meaning server-side.
    """
    clip = SglVideo(path, num_frames)
    return clip
......
......@@ -104,6 +104,7 @@ class RuntimeEndpoint(BaseBackend):
def commit_lazy_operations(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
self._add_videos(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
......@@ -115,6 +116,7 @@ class RuntimeEndpoint(BaseBackend):
def fill_image(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
......@@ -181,6 +183,7 @@ class RuntimeEndpoint(BaseBackend):
data[item] = value
self._add_images(s, data)
self._add_videos(s, data)
res = http_request(
self.base_url + "/generate",
......@@ -222,6 +225,7 @@ class RuntimeEndpoint(BaseBackend):
data["stream"] = True
self._add_images(s, data)
self._add_videos(s, data)
res = http_request(
self.base_url + "/generate",
......@@ -324,6 +328,8 @@ class RuntimeEndpoint(BaseBackend):
def _generate_http_request(self, s: StreamExecutor, data):
self._add_images(s, data)
self._add_videos(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
......@@ -338,6 +344,11 @@ class RuntimeEndpoint(BaseBackend):
assert len(s.images_) == 1, "Only support one image."
data["image_data"] = s.images_[0][1]
def _add_videos(self, s: StreamExecutor, data):
    """Attach the executor's collected video paths to the request payload.

    No-op when no videos were gathered. At most one video per request is
    supported; note the payload carries the whole list (unlike
    ``_add_images``, which sends a single item) — presumably the server
    expects a list here; verify against the /generate endpoint.
    """
    if not s.videos_:
        return
    assert len(s.videos_) == 1, "Only support one video."
    data["video_data"] = s.videos_
def _assert_success(self, res):
if res.status_code != 200:
try:
......
......@@ -16,7 +16,9 @@ class ChatTemplate:
role_prefix_and_suffix: Dict[str, Tuple[str, str]]
stop_str: List[str] = ()
image_token: str = "<image>"
video_token: str = "<video>"
audio_token: str = "<audio>"
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
def get_prefix_and_suffix(
......@@ -161,6 +163,7 @@ register_chat_template(
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
image_token="<|vision_start|><|image_pad|><|vision_end|>",
video_token="<|vision_start|><|video_pad|><|vision_end|>",
)
)
......
......@@ -32,11 +32,7 @@ from sglang.lang.ir import (
SglVarScopeEnd,
SglVideo,
)
from sglang.utils import (
encode_image_base64,
encode_video_base64,
get_exception_traceback,
)
from sglang.utils import encode_image_base64, get_exception_traceback
def run_internal(state, program, func_args, func_kwargs, sync):
......@@ -286,6 +282,7 @@ class StreamExecutor:
# For vision
self.images_ = []
self.cur_images = []
self.videos_ = []
# For fork/join
self.fork_start_text_pos = None
......@@ -372,6 +369,7 @@ class StreamExecutor:
exes[i].cur_role_begin_pos = self.cur_role_begin_pos
exes[i].fork_start_text_pos = len(self.text_)
exes[i].images_ = list(self.images_)
exes[i].videos_ = list(self.videos_)
# TODO(ying): handle API speculative execution
......@@ -508,13 +506,9 @@ class StreamExecutor:
def _execute_video(self, expr: SglVideo):
path = expr.path
num_frames = expr.num_frames
base64_data = encode_video_base64(path, num_frames)
self.images_.append((path, base64_data))
self.cur_images.append((path, base64_data))
self.text_ += self.chat_template.image_token
self.videos_.append(path)
self.text_ += self.chat_template.video_token
def _spec_gen(self, sampling_params):
stop = sampling_params.stop
......
......@@ -445,7 +445,7 @@ class SglImage(SglExpr):
class SglVideo(SglExpr):
def __init__(self, path: str, num_frames: int):
def __init__(self, path: str, num_frames: int = -1):
    # path: local file path or URL of the video source.
    # num_frames: number of frames to sample; -1 appears to defer the
    #   choice to the backend — TODO(review): confirm the sentinel value.
    self.path = path
    self.num_frames = num_frames
......
......@@ -209,52 +209,6 @@ def encode_frame(frame):
return frame_bytes
def encode_video_base64(video_path: str, num_frames: int = 16):
    """Sample ``num_frames`` evenly spaced frames from a video and return them
    as one base64 payload prefixed with ``"video:"``.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: Number of frames to sample, evenly spaced over the clip.

    Returns:
        A ``"video:"``-prefixed base64 string of the concatenated encoded
        frames (frame encoding is delegated to ``encode_frame``).

    Raises:
        IOError: If the video file cannot be opened.
        ValueError: If no frames at all could be decoded.
    """
    import cv2  # pip install opencv-python-headless

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Fixed: original message was missing the space after the colon.
        raise IOError(f"Could not open video file: {video_path}")

    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frames = []
        for _ in range(total_frames):
            ret, frame = cap.read()
            if ret:
                frames.append(frame)
            # else: best-effort — skip unreadable frames and keep going,
            # matching the original behavior.
    finally:
        # Release the capture even if decoding raises (the original leaked
        # the handle on exceptions).
        cap.release()

    if not frames:
        # The original fell through to an IndexError on frames[-1] here.
        raise ValueError(f"Could not decode any frames from: {video_path}")

    # Evenly spaced indices over the reported frame count; clamp to what was
    # actually decoded to avoid IndexError on short reads.
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
    frames = [frames[i] for i in frame_indices if i < len(frames)]

    # If decoding came up short, pad by repeating the last frame until the
    # target count is reached.
    while len(frames) < num_frames:
        frames.append(frames[-1])

    # Encode frames in parallel; ordering is preserved by executor.map.
    with ThreadPoolExecutor() as executor:
        encoded_frames = list(executor.map(encode_frame, frames))

    video_bytes = b"".join(encoded_frames)
    return "video:" + pybase64.b64encode(video_bytes).decode("utf-8")
def _is_chinese_char(cp: int):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment