Unverified commit 823b4429, authored by Mick and committed by GitHub
Browse files

lang: support direct video inference (#9936)


Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
parent 14a4d80e
...@@ -21,7 +21,6 @@ from sglang.bench_serving import ( ...@@ -21,7 +21,6 @@ from sglang.bench_serving import (
) )
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
from sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageContentPart from sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageContentPart
from sglang.utils import encode_video_base64
# type of content fields, can be only prompts or with images/videos # type of content fields, can be only prompts or with images/videos
MsgContent = Union[str, List[ChatCompletionMessageContentPart]] MsgContent = Union[str, List[ChatCompletionMessageContentPart]]
...@@ -325,15 +324,9 @@ def sample_nextqa_requests( ...@@ -325,15 +324,9 @@ def sample_nextqa_requests(
prompt_len = len(prompt_token_ids) prompt_len = len(prompt_token_ids)
output_len = fixed_output_len # max output len, not real output len output_len = fixed_output_len # max output len, not real output len
# video input
base64_data = encode_video_base64(video.path, video.num_frames)
# NOTE: This will be replaced by the expanded length from the server
prompt_len += video.num_frames
# add to content # add to content
content = [ content = [
{"type": "image_url", "image_url": {"url": base64_data}}, {"type": "video_url", "video_url": {"url": video}},
{"type": "text", "text": prompt}, {"type": "text", "text": prompt},
] ]
......
...@@ -33,7 +33,7 @@ ...@@ -33,7 +33,7 @@
"from sglang import assistant, function, gen, system, user\n", "from sglang import assistant, function, gen, system, user\n",
"from sglang import image\n", "from sglang import image\n",
"from sglang import RuntimeEndpoint\n", "from sglang import RuntimeEndpoint\n",
"from sglang.lang.api import set_default_backend\n", "from sglang.lang.api import set_default_backend, video\n",
"from sglang.srt.utils import load_image\n", "from sglang.srt.utils import load_image\n",
"from sglang.test.doc_patch import launch_server_cmd\n", "from sglang.test.doc_patch import launch_server_cmd\n",
"from sglang.utils import print_highlight, terminate_process, wait_for_server\n", "from sglang.utils import print_highlight, terminate_process, wait_for_server\n",
...@@ -421,7 +421,11 @@ ...@@ -421,7 +421,11 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"metadata": {}, "metadata": {
"jupyter": {
"is_executing": true
}
},
"outputs": [], "outputs": [],
"source": [ "source": [
"@function\n", "@function\n",
...@@ -436,6 +440,30 @@ ...@@ -436,6 +440,30 @@
"print_highlight(state[\"answer\"])" "print_highlight(state[\"answer\"])"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ask a question about a video"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@function\n",
"def video_qa(s, video_file, question):\n",
" s += user(video(video_file) + question)\n",
" s += assistant(gen(\"answer\", max_tokens=256))\n",
"\n",
"\n",
"video_url = \"https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/videos/jobs_presenting_ipod.mp4\"\n",
"state = video_qa(video_url, \"What is in the video?\")\n",
"print_highlight(state[\"answer\"])"
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
......
...@@ -229,7 +229,7 @@ def image(expr: SglExpr): ...@@ -229,7 +229,7 @@ def image(expr: SglExpr):
return SglImage(expr) return SglImage(expr)
def video(path: str, num_frames: int): def video(path: str, num_frames: int = -1):
return SglVideo(path, num_frames) return SglVideo(path, num_frames)
......
...@@ -104,6 +104,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -104,6 +104,7 @@ class RuntimeEndpoint(BaseBackend):
def commit_lazy_operations(self, s: StreamExecutor): def commit_lazy_operations(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}} data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data) self._add_images(s, data)
self._add_videos(s, data)
res = http_request( res = http_request(
self.base_url + "/generate", self.base_url + "/generate",
json=data, json=data,
...@@ -115,6 +116,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -115,6 +116,7 @@ class RuntimeEndpoint(BaseBackend):
def fill_image(self, s: StreamExecutor): def fill_image(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}} data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data) self._add_images(s, data)
res = http_request( res = http_request(
self.base_url + "/generate", self.base_url + "/generate",
json=data, json=data,
...@@ -181,6 +183,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -181,6 +183,7 @@ class RuntimeEndpoint(BaseBackend):
data[item] = value data[item] = value
self._add_images(s, data) self._add_images(s, data)
self._add_videos(s, data)
res = http_request( res = http_request(
self.base_url + "/generate", self.base_url + "/generate",
...@@ -222,6 +225,7 @@ class RuntimeEndpoint(BaseBackend): ...@@ -222,6 +225,7 @@ class RuntimeEndpoint(BaseBackend):
data["stream"] = True data["stream"] = True
self._add_images(s, data) self._add_images(s, data)
self._add_videos(s, data)
res = http_request( res = http_request(
self.base_url + "/generate", self.base_url + "/generate",
...@@ -324,6 +328,8 @@ class RuntimeEndpoint(BaseBackend): ...@@ -324,6 +328,8 @@ class RuntimeEndpoint(BaseBackend):
def _generate_http_request(self, s: StreamExecutor, data): def _generate_http_request(self, s: StreamExecutor, data):
self._add_images(s, data) self._add_images(s, data)
self._add_videos(s, data)
res = http_request( res = http_request(
self.base_url + "/generate", self.base_url + "/generate",
json=data, json=data,
...@@ -338,6 +344,11 @@ class RuntimeEndpoint(BaseBackend): ...@@ -338,6 +344,11 @@ class RuntimeEndpoint(BaseBackend):
assert len(s.images_) == 1, "Only support one image." assert len(s.images_) == 1, "Only support one image."
data["image_data"] = s.images_[0][1] data["image_data"] = s.images_[0][1]
def _add_videos(self, s: StreamExecutor, data):
    """Attach the executor's queued video to a request payload.

    When ``s.videos_`` is non-empty it is stored in ``data`` under the
    ``"video_data"`` key; otherwise ``data`` is left untouched.

    Args:
        s: Executor whose ``videos_`` list is read.
        data: Request payload dict, modified in place.

    Raises:
        AssertionError: If more than one video is queued.
    """
    if not s.videos_:
        return
    # Raised explicitly instead of through an ``assert`` statement so the
    # single-video limit is still enforced under ``python -O``; the
    # exception type and message are the same as before.
    if len(s.videos_) != 1:
        raise AssertionError("Only support one video.")
    data["video_data"] = s.videos_
def _assert_success(self, res): def _assert_success(self, res):
if res.status_code != 200: if res.status_code != 200:
try: try:
......
...@@ -16,7 +16,9 @@ class ChatTemplate: ...@@ -16,7 +16,9 @@ class ChatTemplate:
role_prefix_and_suffix: Dict[str, Tuple[str, str]] role_prefix_and_suffix: Dict[str, Tuple[str, str]]
stop_str: List[str] = () stop_str: List[str] = ()
image_token: str = "<image>" image_token: str = "<image>"
video_token: str = "<video>"
audio_token: str = "<audio>" audio_token: str = "<audio>"
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
def get_prefix_and_suffix( def get_prefix_and_suffix(
...@@ -161,6 +163,7 @@ register_chat_template( ...@@ -161,6 +163,7 @@ register_chat_template(
style=ChatTemplateStyle.PLAIN, style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",), stop_str=("<|im_end|>",),
image_token="<|vision_start|><|image_pad|><|vision_end|>", image_token="<|vision_start|><|image_pad|><|vision_end|>",
video_token="<|vision_start|><|video_pad|><|vision_end|>",
) )
) )
......
...@@ -32,11 +32,7 @@ from sglang.lang.ir import ( ...@@ -32,11 +32,7 @@ from sglang.lang.ir import (
SglVarScopeEnd, SglVarScopeEnd,
SglVideo, SglVideo,
) )
from sglang.utils import ( from sglang.utils import encode_image_base64, get_exception_traceback
encode_image_base64,
encode_video_base64,
get_exception_traceback,
)
def run_internal(state, program, func_args, func_kwargs, sync): def run_internal(state, program, func_args, func_kwargs, sync):
...@@ -286,6 +282,7 @@ class StreamExecutor: ...@@ -286,6 +282,7 @@ class StreamExecutor:
# For vision # For vision
self.images_ = [] self.images_ = []
self.cur_images = [] self.cur_images = []
self.videos_ = []
# For fork/join # For fork/join
self.fork_start_text_pos = None self.fork_start_text_pos = None
...@@ -372,6 +369,7 @@ class StreamExecutor: ...@@ -372,6 +369,7 @@ class StreamExecutor:
exes[i].cur_role_begin_pos = self.cur_role_begin_pos exes[i].cur_role_begin_pos = self.cur_role_begin_pos
exes[i].fork_start_text_pos = len(self.text_) exes[i].fork_start_text_pos = len(self.text_)
exes[i].images_ = list(self.images_) exes[i].images_ = list(self.images_)
exes[i].videos_ = list(self.videos_)
# TODO(ying): handle API speculative execution # TODO(ying): handle API speculative execution
...@@ -508,13 +506,9 @@ class StreamExecutor: ...@@ -508,13 +506,9 @@ class StreamExecutor:
def _execute_video(self, expr: SglVideo): def _execute_video(self, expr: SglVideo):
path = expr.path path = expr.path
num_frames = expr.num_frames
base64_data = encode_video_base64(path, num_frames)
self.images_.append((path, base64_data)) self.videos_.append(path)
self.cur_images.append((path, base64_data)) self.text_ += self.chat_template.video_token
self.text_ += self.chat_template.image_token
def _spec_gen(self, sampling_params): def _spec_gen(self, sampling_params):
stop = sampling_params.stop stop = sampling_params.stop
......
...@@ -445,7 +445,7 @@ class SglImage(SglExpr): ...@@ -445,7 +445,7 @@ class SglImage(SglExpr):
class SglVideo(SglExpr): class SglVideo(SglExpr):
def __init__(self, path: str, num_frames: int): def __init__(self, path: str, num_frames: int = -1):
self.path = path self.path = path
self.num_frames = num_frames self.num_frames = num_frames
......
...@@ -209,52 +209,6 @@ def encode_frame(frame): ...@@ -209,52 +209,6 @@ def encode_frame(frame):
return frame_bytes return frame_bytes
def encode_video_base64(video_path: str, num_frames: int = 16) -> str:
    """Sample frames from a video and pack them into one base64 string.

    ``num_frames`` indices are spread evenly over the frame count reported
    by the container. The selected frames are passed through
    ``encode_frame``, their bytes are concatenated, and the result is
    base64-encoded and prefixed with ``"video:"``.

    Args:
        video_path: Location handed to ``cv2.VideoCapture``.
        num_frames: Number of frames to sample.

    Returns:
        ``"video:"`` followed by the base64 text of the concatenated
        encoded frames.

    Raises:
        IOError: If the video cannot be opened, or if frames were requested
            but none could be decoded.
    """
    import logging

    import cv2  # pip install opencv-python-headless

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Could not open video file:{video_path}")

    # try/finally guarantees the capture handle is released even if
    # reading raises part-way through.
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Diagnostic only, so it goes to the logger rather than stdout.
        logging.getLogger(__name__).debug("target_frames: %s", num_frames)

        frame_indices = [
            int(i) for i in np.linspace(0, total_frames - 1, num_frames, dtype=int)
        ]
        wanted = set(frame_indices)
        last_wanted = max(wanted, default=-1)

        # Keep only the frames that will actually be used instead of
        # buffering the whole video. ``decoded`` counts successful reads,
        # so an index refers to the i-th frame that decoded correctly.
        kept = {}
        decoded = 0
        for _ in range(total_frames):
            if decoded > last_wanted:
                break
            ret, frame = cap.read()
            if not ret:
                # Unreadable frame: skip it without advancing the index.
                continue
            if decoded in wanted:
                kept[decoded] = frame
            decoded += 1
    finally:
        cap.release()

    # Indices past the last decodable frame are dropped here and made up
    # for by the padding below.
    frames = [kept[i] for i in frame_indices if i in kept]

    if num_frames > 0 and not frames:
        # There is no last frame to duplicate, so fail with a clear error.
        raise IOError(f"Could not decode any frame from video file:{video_path}")

    # If there are not enough frames, repeat the last one up to the target.
    frames.extend([frames[-1]] * (num_frames - len(frames)))

    # Encode the frames in parallel; ``map`` preserves their order.
    with ThreadPoolExecutor() as executor:
        encoded_frames = list(executor.map(encode_frame, frames))

    video_bytes = b"".join(encoded_frames)
    return "video:" + pybase64.b64encode(video_bytes).decode("utf-8")
def _is_chinese_char(cp: int): def _is_chinese_char(cp: int):
"""Checks whether CP is the codepoint of a CJK character.""" """Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block: # This defines a "chinese character" as anything in the CJK Unicode block:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment