Unverified commit aa47f642, authored by Ying Sheng, committed by GitHub

Revert "[feat] Enable chunked prefill for llava-onevision" (#2329)

parent 3ddb1c46
@@ -128,7 +128,6 @@ class ImageInputs:
     image_hashes: Optional[list] = None
     image_sizes: Optional[list] = None
     image_offsets: Optional[list] = None
-    image_pad_len: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None
...
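The image_pad_len field removed here recorded how many placeholder tokens each image expands to in the prompt (it is filled with new_image_feature_len further down in llava.py). Together with image_offsets it pins down each image's token span, which is what chunked prefill needs to know. A minimal sketch with a simplified stand-in class and invented numbers, not the real sglang ImageInputs:

# Simplified stand-in (not the real sglang ImageInputs) showing how an
# image_pad_len-style field locates each image's token span in the prompt.
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyImageInputs:
    image_offsets: List[int] = field(default_factory=list)  # first token index of each image
    image_pad_len: List[int] = field(default_factory=list)  # placeholder tokens per image


def image_token_spans(inputs: ToyImageInputs):
    # Each image occupies [offset, offset + pad_len) in the flattened prompt.
    return [
        (off, off + pad)
        for off, pad in zip(inputs.image_offsets, inputs.image_pad_len)
    ]


if __name__ == "__main__":
    toy = ToyImageInputs(image_offsets=[5, 900], image_pad_len=[576, 576])
    print(image_token_spans(toy))  # [(5, 581), (900, 1476)]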
@@ -111,20 +111,15 @@ class ModelRunner:
         )
         if self.is_multimodal:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = -1
             self.mem_fraction_static *= 0.95
-            if self.model_config.hf_config.architectures == [
-                "MllamaForConditionalGeneration"
-            ]:
-                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-                server_args.chunked_prefill_size = -1
             # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
-                logger.info(
-                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
-                )
-                server_args.chunked_prefill_size = -1
                 server_args.disable_radix_cache = True

         # Global vars
...
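The restored behavior above disables chunked prefill for every multimodal model, scales down the static memory fraction, and additionally disables the radix cache for qwen2-vl. A standalone sketch of that rule, using toy stand-ins rather than sglang's real ServerArgs and ModelRunner; the default of 8192 and the demo architecture string are invented for the example:

# Toy stand-in for the restored ModelRunner logic; ToyServerArgs and its
# defaults are invented for the example, not sglang's real classes.
from dataclasses import dataclass
from typing import List


@dataclass
class ToyServerArgs:
    chunked_prefill_size: int = 8192
    disable_radix_cache: bool = False


def apply_multimodal_overrides(
    server_args: ToyServerArgs,
    is_multimodal: bool,
    architectures: List[str],
    mem_fraction_static: float,
) -> float:
    if is_multimodal:
        # Restored behavior: chunked prefill is disabled for all multimodal models.
        server_args.chunked_prefill_size = -1
        mem_fraction_static *= 0.95
        # qwen2-vl does not support radix cache, so it is disabled as well.
        if architectures == ["Qwen2VLForConditionalGeneration"]:
            server_args.disable_radix_cache = True
    return mem_fraction_static


if __name__ == "__main__":
    args = ToyServerArgs()
    frac = apply_multimodal_overrides(args, True, ["AnyMultimodalArchitecture"], 0.88)
    print(args, round(frac, 3))  # chunked_prefill_size=-1, disable_radix_cache=False, 0.836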
@@ -57,7 +57,6 @@ class LlavaBaseForCausalLM(nn.Module):
         else:
             image_aspect_ratio = "anyres"
         offset_list = []
-        image_inputs.image_pad_len = []
         for image_idx, image_s in enumerate(image_sizes):
             if len(image_sizes) > 16:
                 # 2x2 pooling with stride 2
@@ -104,7 +103,6 @@ class LlavaBaseForCausalLM(nn.Module):
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)
-            image_inputs.image_pad_len.append(new_image_feature_len)
         image_inputs.image_offsets = offset_list
         return input_ids
@@ -136,14 +134,6 @@ class LlavaBaseForCausalLM(nn.Module):
         image_inputs = forward_batch.image_inputs

         if forward_batch.forward_mode.is_extend():
-            # Clamp input ids. This is because the input_ids for the image tokens are
-            # filled with the hash values of the image for the prefix matching in the radix attention.
-            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-
-            # Embed text inputs
-            input_embeds = self.language_model.model.embed_tokens(input_ids)
-
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
@@ -152,12 +142,18 @@ class LlavaBaseForCausalLM(nn.Module):
                 if im and im.modalities is not None:
                     modalities_list.extend(im.modalities)
                 if im and im.image_offsets:
-                    max_image_offset.append(
-                        np.max(np.array(im.image_offsets) + np.array(im.image_pad_len))
-                    )
+                    max_image_offset.append(max(im.image_offsets))
                 else:
                     max_image_offset.append(-1)

+            # Clamp input ids. This is because the input_ids for the image tokens are
+            # filled with the hash values of the image for the prefix matching in the radix attention.
+            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
+            # Embed text inputs
+            input_embeds = self.language_model.model.embed_tokens(input_ids)
+
             start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
             need_vision = start_positions <= np.array(max_image_offset)
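This hunk is the heart of the revert: with image_pad_len, the per-request maximum pointed at the last padded token of each image, so a prefill chunk starting in the middle of an image still took the vision path; with max(im.image_offsets) it points only at each image's first token. A small numeric illustration, with made-up offsets and lengths:

# Illustrative numbers only; the offsets and pad lengths are made up.
import numpy as np

image_offsets = [10, 700]   # first token index of each image in the prompt
image_pad_len = [576, 576]  # placeholder tokens each image occupies
start_position = 900        # hypothetical start of the current prefill chunk

# Reverted code: only the first token of the last image is considered.
max_offset_start_only = max(image_offsets)                                             # 700
# Feature code: the last padded token of each image is considered instead.
max_offset_with_pad = int(np.max(np.array(image_offsets) + np.array(image_pad_len)))  # 1276

print(start_position <= max_offset_start_only)  # False -> vision path would be skipped
print(start_position <= max_offset_with_pad)    # True  -> chunk still overlaps the second image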
@@ -354,7 +350,6 @@ class LlavaBaseForCausalLM(nn.Module):
                 # Fill in the placeholder for the image
                 extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-                extend_seq_lens = forward_batch.extend_seq_lens.cpu().numpy()
                 prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
                 pt = 0
                 for i in range(bs):
@@ -362,36 +357,18 @@ class LlavaBaseForCausalLM(nn.Module):
                         continue

                     start_idx = extend_start_loc_cpu[i]
-                    seq_len = extend_seq_lens[i]
                     prefix_len = prefix_lens_cpu[i]

                     # Multiple images
-                    for image_idx, image_offset in enumerate(
-                        image_inputs[i].image_offsets
-                    ):
-                        if (
-                            image_offset + image_inputs[i].image_pad_len[image_idx]
-                            <= prefix_len
-                        ):
+                    for j, image_offset in enumerate(image_inputs[i].image_offsets):
+                        if image_offset < prefix_len:
                             continue
-                        if image_offset >= prefix_len + seq_len:
-                            break

-                        tmp_image_feature = image_features[pt][image_idx]
+                        tmp_image_feature = image_features[pt][j]
                         pad_len = tmp_image_feature.shape[0]

-                        input_offset = image_offset - prefix_len
-                        left_idx = start_idx + input_offset
-                        right_idx = left_idx + pad_len
-                        assert right_idx > start_idx
-                        if input_offset < 0:
-                            left_idx = start_idx
-                            tmp_image_feature = tmp_image_feature[-input_offset:]
-                        if right_idx > start_idx + seq_len:
-                            tmp_image_feature = tmp_image_feature[
-                                : start_idx + seq_len - right_idx
-                            ]
-                            right_idx = start_idx + seq_len
+                        left_idx = start_idx + (image_offset - prefix_len)
+                        right_idx = start_idx + (image_offset - prefix_len) + pad_len
                         try:
                             input_embeds[left_idx:right_idx] = tmp_image_feature
                         except RuntimeError as e:
...
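The block removed in the last hunk clipped each image's features to the part that falls inside the current chunk, [prefix_len, prefix_len + seq_len), before writing them into input_embeds. A standalone reconstruction of that clipping arithmetic, with invented shapes and offsets:

# Standalone reconstruction of the removed clipping logic; shapes and offsets
# are invented for the example.
import torch

seq_len, hidden = 512, 16         # tokens in the current chunk, embedding width
prefix_len = 300                  # tokens already prefilled in earlier chunks
image_offset, pad_len = 200, 576  # image starts 100 tokens before this chunk

input_embeds = torch.zeros(seq_len, hidden)
image_feature = torch.randn(pad_len, hidden)

start_idx = 0                              # single sequence, so the chunk starts at row 0
input_offset = image_offset - prefix_len   # -100: part of the image was already prefilled
left_idx = start_idx + input_offset
right_idx = left_idx + pad_len

if input_offset < 0:                       # drop the part handled by earlier chunks
    left_idx = start_idx
    image_feature = image_feature[-input_offset:]
if right_idx > start_idx + seq_len:        # drop the part belonging to later chunks
    image_feature = image_feature[: start_idx + seq_len - right_idx]
    right_idx = start_idx + seq_len

input_embeds[left_idx:right_idx] = image_feature
print(image_feature.shape)  # torch.Size([476, 16]): 576 - 100 rows land in this chunk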
@@ -39,7 +39,6 @@ suites = {
         "test_triton_attention_kernels.py",
         "test_triton_attention_backend.py",
         "test_update_weights_from_disk.py",
-        "test_vision_chunked_prefill.py",
         "test_vision_openai_server.py",
         "test_session_control.py",
     ],
...

deleted file: test_vision_chunked_prefill.py
"""
Usage:
python3 -m unittest test_vision_chunked_prefill.TestVisionChunkedPrefill.test_chunked_prefill
"""
import base64
import io
import os
import unittest
from concurrent.futures import ThreadPoolExecutor
from typing import Union
import numpy as np
import requests
from decord import VideoReader, cpu
from PIL import Image
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)


class TestVisionChunkedPrefill(unittest.TestCase):
def prepare_video_messages(self, video_path, max_frames_num=8):
vr = VideoReader(video_path, ctx=cpu(0))
total_frame_num = len(vr)
uniform_sampled_frames = np.linspace(
0, total_frame_num - 1, max_frames_num, dtype=int
)
frame_idx = uniform_sampled_frames.tolist()
frames = vr.get_batch(frame_idx).asnumpy()
base64_frames = []
for frame in frames:
pil_img = Image.fromarray(frame)
buff = io.BytesIO()
pil_img.save(buff, format="JPEG")
base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
base64_frames.append(base64_str)
messages = [{"role": "user", "content": []}]
        for base64_frame in base64_frames:
            # Build a fresh entry per frame; reusing one template dict with a
            # shallow copy would make every frame share the same image_url dict
            # and end up pointing at the last frame's URL.
            frame_entry = {
                "type": "image_url",
                "image_url": {"url": "data:image/jpeg;base64,{}".format(base64_frame)},
                "modalities": "video",
            }
            messages[0]["content"].append(frame_entry)
prompt = {"type": "text", "text": "Please describe the video briefly."}
messages[0]["content"].append(prompt)
        return messages

    def get_prompt_from_messages(self, messages):
text = (
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
)
image_data = []
for content in messages[0]["content"]:
if content["type"] == "image_url":
text += "<image>\n"
image_data.append(content["image_url"]["url"])
text += "Please describe the video briefly.<|im_end|>\n<|im_start|>assistant\n"
        return text, image_data

    def generate(self, text, image_data):
response = requests.post(
self.base_url + "/generate",
json={
"text": text,
"image_data": image_data,
"sampling_params": {
"temperature": 0,
"max_new_tokens": 32,
"no_stop_trim": True,
"skip_special_tokens": False,
},
"modalities": ["multi-images"],
},
).json()
        return response["text"]

    def generate_for_video(self, batch, num_frame) -> Union[str, list[str]]:
# prepare the video input about Steven introducing ipod nano
url = "https://raw.githubusercontent.com/evolvinglmms-lab/sglang/dev/onevision_local/assets/jobs.mp4"
cache_dir = os.path.expanduser("~/.cache")
file_path = os.path.join(cache_dir, "jobs.mp4")
os.makedirs(cache_dir, exist_ok=True)
if not os.path.exists(file_path):
response = requests.get(url)
response.raise_for_status()
with open(file_path, "wb") as f:
f.write(response.content)
if not batch:
assert isinstance(num_frame, int)
messages = self.prepare_video_messages(file_path, max_frames_num=num_frame)
text, image_data = self.get_prompt_from_messages(messages)
return self.generate(text, image_data)
else:
assert isinstance(num_frame, list)
func_args = []
for max_frames_num in num_frame:
messages = self.prepare_video_messages(
file_path,
max_frames_num=max_frames_num,
)
text, image_data = self.get_prompt_from_messages(messages)
func_args.append((text, image_data))
with ThreadPoolExecutor(max_workers=10) as executor:
responses = list(executor.map(lambda p: self.generate(*p), func_args))
            return responses

    def run_generate(self, chunked_prefill_size, batch, num_frame):
# launch server
model = "lmms-lab/llava-onevision-qwen2-7b-ov"
# model = "meta-llama/Llama-3.2-11B-Vision-Instruct"
self.base_url = DEFAULT_URL_FOR_TEST
process = popen_launch_server(
model,
self.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--chunked-prefill-size",
f"{chunked_prefill_size}",
],
)
try:
return self.generate_for_video(batch, num_frame)
finally:
            kill_process_tree(process.pid)

    def test_chunked_prefill(self):
output_chunked = self.run_generate(
chunked_prefill_size=1024, batch=False, num_frame=1
)
output_no_chunked = self.run_generate(
chunked_prefill_size=-1, batch=False, num_frame=1
)
print("output with chunked prefill:")
print(output_chunked)
print("output without chunked prefill:")
print(output_no_chunked)
assert output_chunked == output_no_chunked
output_chunked = self.run_generate(
chunked_prefill_size=1024, batch=True, num_frame=[2, 6, 8, 10]
)
output_no_chunked = self.run_generate(
chunked_prefill_size=-1, batch=True, num_frame=[2, 6, 8, 10]
)
print("output with chunked prefill:")
print(output_chunked)
print("output without chunked prefill:")
print(output_no_chunked)
        assert output_chunked == output_no_chunked


if __name__ == "__main__":
unittest.main()
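For reference, the request body this deleted test posts to the server's /generate endpoint has roughly the following shape; the base64 frame below is a placeholder, not real image bytes:

# Rough shape of the /generate payload built by the deleted test.
import json

payload = {
    "text": (
        "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
        "<|im_start|>user\n<image>\n"
        "Please describe the video briefly.<|im_end|>\n<|im_start|>assistant\n"
    ),
    "image_data": ["data:image/jpeg;base64,<frame-0-bytes>"],
    "sampling_params": {
        "temperature": 0,
        "max_new_tokens": 32,
        "no_stop_trim": True,
        "skip_special_tokens": False,
    },
    "modalities": ["multi-images"],
}
print(json.dumps(payload, indent=2))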