Unverified Commit ffc722a6 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Revert "lang: support direct video inference" (#12038)

parent 49afb3d9
......@@ -21,6 +21,7 @@ from sglang.bench_serving import (
)
from sglang.lang.chat_template import get_chat_template, get_chat_template_by_model_path
from sglang.srt.entrypoints.openai.protocol import ChatCompletionMessageContentPart
from sglang.utils import encode_video_base64
# type of content fields, can be only prompts or with images/videos
MsgContent = Union[str, List[ChatCompletionMessageContentPart]]
......@@ -324,9 +325,15 @@ def sample_nextqa_requests(
prompt_len = len(prompt_token_ids)
output_len = fixed_output_len # max output len, not real output len
# video input
base64_data = encode_video_base64(video.path, video.num_frames)
# NOTE: This will be replaced by the expanded length from the server
prompt_len += video.num_frames
# add to content
content = [
{"type": "video_url", "video_url": {"url": video}},
{"type": "image_url", "image_url": {"url": base64_data}},
{"type": "text", "text": prompt},
]
......
......@@ -33,7 +33,7 @@
"from sglang import assistant, function, gen, system, user\n",
"from sglang import image\n",
"from sglang import RuntimeEndpoint\n",
"from sglang.lang.api import set_default_backend, video\n",
"from sglang.lang.api import set_default_backend\n",
"from sglang.srt.utils import load_image\n",
"from sglang.test.doc_patch import launch_server_cmd\n",
"from sglang.utils import print_highlight, terminate_process, wait_for_server\n",
......@@ -421,11 +421,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"jupyter": {
"is_executing": true
}
},
"metadata": {},
"outputs": [],
"source": [
"@function\n",
......@@ -440,30 +436,6 @@
"print_highlight(state[\"answer\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ask a question about a video"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@function\n",
"def video_qa(s, video_file, question):\n",
" s += user(video(video_file) + question)\n",
" s += assistant(gen(\"answer\", max_tokens=256))\n",
"\n",
"\n",
"video_url = \"https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/videos/jobs_presenting_ipod.mp4\"\n",
"state = video_qa(video_url, \"What is in the video?\")\n",
"print_highlight(state[\"answer\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
......
......@@ -229,7 +229,7 @@ def image(expr: SglExpr):
return SglImage(expr)
def video(path: str, num_frames: int = -1):
    """Create a video expression for use in an sglang program.

    Args:
        path: Path or URL of the video file.
        num_frames: Number of frames to sample from the video; -1 lets the
            consumer choose a default.

    Returns:
        An ``SglVideo`` IR node wrapping the path and frame count.
    """
    return SglVideo(path, num_frames)
......
......@@ -104,7 +104,6 @@ class RuntimeEndpoint(BaseBackend):
def commit_lazy_operations(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
self._add_videos(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
......@@ -116,7 +115,6 @@ class RuntimeEndpoint(BaseBackend):
def fill_image(self, s: StreamExecutor):
data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
self._add_images(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
......@@ -183,7 +181,6 @@ class RuntimeEndpoint(BaseBackend):
data[item] = value
self._add_images(s, data)
self._add_videos(s, data)
res = http_request(
self.base_url + "/generate",
......@@ -225,7 +222,6 @@ class RuntimeEndpoint(BaseBackend):
data["stream"] = True
self._add_images(s, data)
self._add_videos(s, data)
res = http_request(
self.base_url + "/generate",
......@@ -328,8 +324,6 @@ class RuntimeEndpoint(BaseBackend):
def _generate_http_request(self, s: StreamExecutor, data):
self._add_images(s, data)
self._add_videos(s, data)
res = http_request(
self.base_url + "/generate",
json=data,
......@@ -344,11 +338,6 @@ class RuntimeEndpoint(BaseBackend):
assert len(s.images_) == 1, "Only support one image."
data["image_data"] = s.images_[0][1]
def _add_videos(self, s: StreamExecutor, data):
    """Attach any pending video data from the executor state to the request payload."""
    if not s.videos_:
        return
    # The runtime currently accepts at most one video per request.
    assert len(s.videos_) == 1, "Only support one video."
    data["video_data"] = s.videos_
def _assert_success(self, res):
if res.status_code != 200:
try:
......
......@@ -16,9 +16,7 @@ class ChatTemplate:
role_prefix_and_suffix: Dict[str, Tuple[str, str]]
stop_str: List[str] = ()
image_token: str = "<image>"
video_token: str = "<video>"
audio_token: str = "<audio>"
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
def get_prefix_and_suffix(
......@@ -163,7 +161,6 @@ register_chat_template(
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
image_token="<|vision_start|><|image_pad|><|vision_end|>",
video_token="<|vision_start|><|video_pad|><|vision_end|>",
)
)
......
......@@ -32,7 +32,11 @@ from sglang.lang.ir import (
SglVarScopeEnd,
SglVideo,
)
from sglang.utils import encode_image_base64, get_exception_traceback
from sglang.utils import (
encode_image_base64,
encode_video_base64,
get_exception_traceback,
)
def run_internal(state, program, func_args, func_kwargs, sync):
......@@ -282,7 +286,6 @@ class StreamExecutor:
# For vision
self.images_ = []
self.cur_images = []
self.videos_ = []
# For fork/join
self.fork_start_text_pos = None
......@@ -369,7 +372,6 @@ class StreamExecutor:
exes[i].cur_role_begin_pos = self.cur_role_begin_pos
exes[i].fork_start_text_pos = len(self.text_)
exes[i].images_ = list(self.images_)
exes[i].videos_ = list(self.videos_)
# TODO(ying): handle API speculative execution
......@@ -506,9 +508,13 @@ class StreamExecutor:
def _execute_video(self, expr: SglVideo):
path = expr.path
num_frames = expr.num_frames
base64_data = encode_video_base64(path, num_frames)
self.videos_.append(path)
self.text_ += self.chat_template.video_token
self.images_.append((path, base64_data))
self.cur_images.append((path, base64_data))
self.text_ += self.chat_template.image_token
def _spec_gen(self, sampling_params):
stop = sampling_params.stop
......
......@@ -445,7 +445,7 @@ class SglImage(SglExpr):
class SglVideo(SglExpr):
    """IR node representing a video input in an sglang program.

    Holds only the location and the desired sampling density; the actual
    decoding/encoding happens later when the expression is executed.
    """

    def __init__(self, path: str, num_frames: int = -1):
        # Path or URL of the video file.
        self.path = path
        # Number of frames to sample; -1 lets the consumer pick a default.
        self.num_frames = num_frames
......
......@@ -209,6 +209,52 @@ def encode_frame(frame):
return frame_bytes
def encode_video_base64(video_path: str, num_frames: int = 16):
    """Sample frames from a video and return them as one base64 payload.

    Decodes the whole video with OpenCV, selects ``num_frames`` frames evenly
    spaced across the reported frame count, encodes each frame in parallel via
    ``encode_frame``, concatenates the resulting bytes, and base64-encodes
    them with a ``"video:"`` prefix so consumers can distinguish video
    payloads from plain image payloads.

    Args:
        video_path: Path to a video file readable by OpenCV.
        num_frames: Number of frames to sample (evenly spaced).

    Returns:
        A ``"video:"``-prefixed base64 string of the concatenated frame bytes.

    Raises:
        IOError: If the video file cannot be opened.
        ValueError: If no frames at all could be decoded from the video.
    """
    import cv2  # pip install opencv-python-headless

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Could not open video file: {video_path}")

    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Evenly spaced sample indices over the reported frame count.
    frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)

    frames = []
    for _ in range(total_frames):
        ret, frame = cap.read()
        if ret:
            frames.append(frame)
        # Unreadable frames are silently skipped; indices are clamped below.
    cap.release()

    if not frames:
        # Guard against frames[-1] below raising IndexError on an empty list.
        raise ValueError(f"Could not decode any frames from: {video_path}")

    # Select the sampled frames, ignoring indices past the decodable range.
    frames = [frames[i] for i in frame_indices if i < len(frames)]

    # If the video yielded fewer decodable frames than requested, pad by
    # repeating the last frame until we reach the target count.
    while len(frames) < num_frames:
        frames.append(frames[-1])

    # Encode frames in parallel (encode_frame does the per-frame work).
    with ThreadPoolExecutor() as executor:
        encoded_frames = list(executor.map(encode_frame, frames))

    # Concatenate all frame bytes, then base64-encode with the video marker.
    video_bytes = b"".join(encoded_frames)
    return "video:" + pybase64.b64encode(video_bytes).decode("utf-8")
def _is_chinese_char(cp: int):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment