Unverified Commit 37c8a576 authored by Ying Sheng, committed by GitHub

[feat] Support session control for vision language models (#2210)

parent c754652f
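What the feature looks like from the client side: a minimal sketch of the session-control flow this commit enables for vision models, assuming a locally launched sglang server. The `base_url`, prompt strings, and use of the plain `text` field are illustrative; the test added at the bottom of this diff drives the same flow with pre-tokenized `input_ids`.

```python
import requests

base_url = "http://127.0.0.1:30000"  # assumption: a locally launched sglang server

# Open a session; the server returns a session id.
session_id = requests.post(
    base_url + "/open_session", json={"capacity_of_str_len": 1000}
).json()

# First turn: a prompt that references an image, bound to the session.
response = requests.post(
    base_url + "/generate",
    json={
        "text": "<|im_start|>user\n<image>\nDescribe this image.<|im_end|>\n<|im_start|>assistant\n",
        "image_data": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png",
        "modalities": ["multi-images"],
        "session": [session_id, None],  # no parent request yet
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
).json()
rid = response["meta_info"]["id"]  # continue from (or backtrack to) this request

# Follow-up turn: send only the new chunk; the session supplies the history,
# including the image inputs carried over by this commit.
response = requests.post(
    base_url + "/generate",
    json={
        "text": "<|im_start|>user\nIs it a logo?<|im_end|>\n<|im_start|>assistant\n",
        "session": [session_id, rid],
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
).json()

requests.post(base_url + "/close_session", json={"session_id": session_id})
```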
@@ -131,6 +131,7 @@ class LlavaImageProcessor(BaseImageProcessor):
         if not image_data:
             return None

+        modalities = request_obj.modalities or ["image"]
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (
             self.hf_config.image_grid_pinpoints
@@ -139,9 +140,12 @@ class LlavaImageProcessor(BaseImageProcessor):
             else None
         )

+        if isinstance(image_data, str):
+            image_data = [image_data]
+
         if isinstance(image_data, list) and len(image_data) > 0:
-            # Multiple images
-            if len(image_data) > 1:
+            if "multi-images" in modalities or "video" in modalities:
+                # Multiple images
                 aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
                 pixel_values, image_hashes, image_sizes = [], [], []
                 res = []
@@ -166,13 +170,6 @@ class LlavaImageProcessor(BaseImageProcessor):
                 )
                 image_hashes = [image_hash]
                 image_sizes = [image_size]
-        elif isinstance(image_data, str):
-            # A single image
-            pixel_values, image_hash, image_size = await self._process_single_image(
-                image_data, aspect_ratio, grid_pinpoints
-            )
-            image_hashes = [image_hash]
-            image_sizes = [image_size]
         else:
             raise ValueError(f"Invalid image data: {image_data}")
...
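In effect, the processor now normalizes a bare image string to a one-element list and picks the aspect-ratio mode from the declared modalities rather than from the image count. A condensed, hypothetical rendering of that dispatch (`choose_aspect_ratio` is not a real function in the patch):

```python
# Hypothetical illustration of the new dispatch, not the real processor code.
def choose_aspect_ratio(image_data, modalities, default="anyres"):
    # Normalize a bare URL/path into a one-element list.
    if isinstance(image_data, str):
        image_data = [image_data]
    if not (isinstance(image_data, list) and image_data):
        raise ValueError(f"Invalid image data: {image_data}")
    # Interleaved-image and video requests never use anyres.
    if "multi-images" in modalities or "video" in modalities:
        return "pad"
    return default

assert choose_aspect_ratio("img.png", ["image"]) == "anyres"
assert choose_aspect_ratio(["a.png", "b.png"], ["multi-images"]) == "pad"
```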
@@ -31,6 +31,7 @@
 import dataclasses
 import logging
 from typing import List, Optional, Tuple, Union

+import numpy as np
 import torch
 import triton
 import triton.language as tl
@@ -167,6 +168,30 @@ class ImageInputs:
         return ret
+    def merge(self, other, vocab_size):
+        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
+        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
+        self.image_hashes += other.image_hashes
+        # Recompute pad values from one hash over the merged image set, so the
+        # modulo arithmetic runs on an int rather than on the list of hashes.
+        combined_hash = hash(tuple(self.image_hashes))
+        self.pad_values = [
+            combined_hash % vocab_size,
+            (combined_hash >> 16) % vocab_size,
+            (combined_hash >> 32) % vocab_size,
+            (combined_hash >> 64) % vocab_size,
+        ]
+
+        optional_args = [
+            "image_sizes",
+            "image_offsets",
+            # "modalities",  # modalities should be ["multi-images"] (one entry) even for multiple images
+            "aspect_ratio_ids",
+            "aspect_ratio_mask",
+            "image_grid_thws",
+        ]
+        for arg in optional_args:
+            if getattr(self, arg, None) is not None:
+                setattr(self, arg, getattr(self, arg) + getattr(other, arg))
 class Req:
     """The input and output status of a request."""

@@ -177,6 +202,7 @@ class Req:
         origin_input_text: str,
         origin_input_ids: Tuple[int],
         sampling_params: SamplingParams,
+        origin_input_ids_unpadded: Optional[Tuple[int]] = None,
         lora_path: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         session_id: Optional[str] = None,
@@ -184,7 +210,11 @@ class Req:
         # Input and output info
         self.rid = rid
         self.origin_input_text = origin_input_text
-        self.origin_input_ids_unpadded = origin_input_ids  # Before image padding
+        self.origin_input_ids_unpadded = (
+            origin_input_ids_unpadded
+            if origin_input_ids_unpadded
+            else origin_input_ids  # Before image padding
+        )
         self.origin_input_ids = origin_input_ids
         self.output_ids = []  # Each decode stage's output ids
         self.fill_ids = None  # fill_ids = origin_input_ids + output_ids
@@ -260,6 +290,12 @@ class Req:
         # The number of cached tokens, that were already cached in the KV cache
         self.cached_tokens = 0

+    def extend_image_inputs(self, image_inputs, vocab_size):
+        if self.image_inputs is None:
+            self.image_inputs = image_inputs
+        else:
+            self.image_inputs.merge(image_inputs, vocab_size)
+
     # whether request reached finished condition
     def finished(self) -> bool:
         return self.finished_reason is not None
...
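For intuition about `ImageInputs.merge` above: it stacks the pixel tensors, concatenates the per-image hashes, and re-derives the pad values. A self-contained toy stand-in (a simplified sketch, not the real class; the shift-and-modulo pad-value scheme mirrors the patch):

```python
import numpy as np

# Simplified stand-in for ImageInputs, just to illustrate merge semantics.
class ToyImageInputs:
    def __init__(self, pixel_values, image_hashes, vocab_size):
        self.pixel_values = pixel_values  # shape: (num_images, ...)
        self.image_hashes = image_hashes  # one hash per image
        self._recompute_pad_values(vocab_size)

    def _recompute_pad_values(self, vocab_size):
        # Derive pad token ids from a single hash over all images, so the
        # same image set always pads with the same (cacheable) token ids.
        h = hash(tuple(self.image_hashes))
        self.pad_values = [(h >> s) % vocab_size for s in (0, 16, 32, 64)]

    def merge(self, other, vocab_size):
        # Per-image tensors must agree on everything but the batch dim.
        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
        self.image_hashes += other.image_hashes
        self._recompute_pad_values(vocab_size)

a = ToyImageInputs(np.zeros((1, 3, 8, 8)), [11], vocab_size=32000)
b = ToyImageInputs(np.ones((2, 3, 8, 8)), [22, 33], vocab_size=32000)
a.merge(b, vocab_size=32000)
assert a.pixel_values.shape == (3, 3, 8, 8) and len(a.image_hashes) == 3
```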
@@ -559,12 +559,13 @@ class Scheduler:
         # Image inputs
         if recv_req.image_inputs is not None:
-            req.image_inputs = ImageInputs.from_dict(
+            image_inputs = ImageInputs.from_dict(
                 recv_req.image_inputs, self.model_config.vocab_size
             )
             req.origin_input_ids = self.pad_input_ids_func(
-                req.origin_input_ids_unpadded, req.image_inputs
+                req.origin_input_ids, image_inputs
             )
+            req.extend_image_inputs(image_inputs, self.model_config.vocab_size)

             if len(req.origin_input_ids) > self.max_req_input_len:
                 req.finished_reason = FINISH_ABORT(
...
@@ -41,16 +41,27 @@ class Session:
                 ]
                 + req.input_ids
             )
+            input_ids_unpadded = (
+                self.reqs[-1].origin_input_ids_unpadded
+                + self.reqs[-1].output_ids[
+                    : self.reqs[-1].sampling_params.max_new_tokens
+                ]
+                + req.input_ids
+            )
         else:
             input_ids = req.input_ids
+            input_ids_unpadded = req.input_ids
         new_req = Req(
-            req.rid,
-            None,
-            input_ids,
-            req.sampling_params,
+            rid=req.rid,
+            origin_input_text=None,
+            origin_input_ids=input_ids,
+            origin_input_ids_unpadded=input_ids_unpadded,
+            sampling_params=req.sampling_params,
             lora_path=req.lora_path,
             session_id=self.session_id,
         )
+        if len(self.reqs) > 0:
+            new_req.image_inputs = self.reqs[-1].image_inputs
         new_req.tokenizer = tokenizer
         if req.session_rid is not None and len(self.reqs) == 0:
             new_req.finished_reason = FINISH_ABORT(
...
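The concatenation in `Session` above is what a follow-up request actually prefills: the previous request's unpadded prompt, at most `max_new_tokens` of its output, then the new chunk. A tiny worked example with made-up token ids:

```python
# Hypothetical token ids, just to show how a follow-up request's ids are built.
prev_origin_input_ids_unpadded = [1, 5, 9]  # previous turn's prompt (pre image padding)
prev_output_ids = [7, 8, 2, 4]              # tokens the previous turn generated
prev_max_new_tokens = 3                     # sampling cap of the previous turn
new_turn_ids = [6, 6]                       # the user's next chunk

# Replay at most max_new_tokens of the previous output, then append the new
# chunk -- matching the concatenation in Session.create_req above.
input_ids_unpadded = (
    prev_origin_input_ids_unpadded
    + prev_output_ids[:prev_max_new_tokens]
    + new_turn_ids
)
assert input_ids_unpadded == [1, 5, 9, 7, 8, 2, 6, 6]
```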
@@ -49,7 +49,13 @@ class LlavaBaseForCausalLM(nn.Module):
         image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values

         # hardcode for spatial_unpad + anyres
-        image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
+        if image_inputs.modalities is not None and (
+            "multi-images" in image_inputs.modalities
+            or "video" in image_inputs.modalities
+        ):
+            image_aspect_ratio = "pad"
+        else:
+            image_aspect_ratio = "anyres"
         offset_list = []
         for image_s in image_sizes:
             if len(image_sizes) > 16:
...
@@ -36,6 +36,7 @@ suites = {
         "test_triton_attention_backend.py",
         "test_update_weights.py",
         "test_vision_openai_server.py",
+        "test_session_control.py",
     ],
     "sampling/penaltylib": glob.glob(
         "sampling/penaltylib/**/test_*.py", recursive=True
...
""" """
Usage: Usage:
python3 -m unittest test_session_control.TestSessionControl.test_session_control python3 -m unittest test_session_control.TestSessionControl.test_session_control
python3 -m unittest test_session_control.TestSessionControl.test_session_control_vlm python3 -m unittest test_session_control.TestSessionControlVision.test_session_control
""" """
import unittest import unittest
@@ -61,6 +61,8 @@ class TestSessionControl(unittest.TestCase):
                         "max_new_tokens": (
                             16 if i > 0 else 0
                         ),  # prefill only for the first chunk
+                        "no_stop_trim": True,
+                        "skip_special_tokens": False,
                     },
                 },
             ).json()
@@ -79,6 +81,8 @@ class TestSessionControl(unittest.TestCase):
                 "sampling_params": {
                     "temperature": 0,
                     "max_new_tokens": 16,
+                    "no_stop_trim": True,
+                    "skip_special_tokens": False,
                 },
             },
         ).json()
@@ -93,6 +97,8 @@ class TestSessionControl(unittest.TestCase):
                 "sampling_params": {
                     "temperature": 0,
                     "max_new_tokens": 16,
+                    "no_stop_trim": True,
+                    "skip_special_tokens": False,
                 },
             },
         ).json()
@@ -113,6 +119,8 @@ class TestSessionControl(unittest.TestCase):
                 "sampling_params": {
                     "temperature": 0,
                     "max_new_tokens": 16,
+                    "no_stop_trim": True,
+                    "skip_special_tokens": False,
                 },
             },
         ).json()
@@ -133,13 +141,16 @@ class TestSessionControl(unittest.TestCase):
                         "max_new_tokens": (
                             16 if i > 0 else 0
                         ),  # prefill only for the first chunk
+                        "no_stop_trim": True,
+                        "skip_special_tokens": False,
                     },
                 },
             ).json()
             if i > 0:
-                input_ids += tokenizer.encode(response["text"])[
-                    1:
-                ]  # drop the bos token
+                output_ids = tokenizer.encode(response["text"])
+                if output_ids[0] == tokenizer.bos_token_id:
+                    output_ids = output_ids[1:]
+                input_ids += output_ids
             outputs_normal.append(response["text"])
             if i == 0:
                 input_ids_first_req = input_ids.copy()
@@ -152,6 +163,187 @@ class TestSessionControl(unittest.TestCase):
                 "sampling_params": {
                     "temperature": 0,
                     "max_new_tokens": 16,
+                    "no_stop_trim": True,
+                    "skip_special_tokens": False,
+                },
+            },
+        ).json()
+        outputs_normal.append(response["text"])
+
+        print("outputs from chunked queries with session control:")
+        print(outputs_from_session)
+        print("outputs from normal queries:")
+        print(outputs_normal)
+        assert outputs_from_session == outputs_normal
+
+
+class TestSessionControlVision(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = "lmms-lab/llava-onevision-qwen2-7b-ov"
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            # other_args={"--disable-radix"},
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid, include_self=True)
+
+    def test_session_control(self):
+        text_chunks = [
+            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n",
+            "<|im_start|>user\n<image>\nDescribe this image in a very short sentence.<|im_end|>\n<|im_start|>assistant\n",
+            "<|im_start|>user\n<image>\nIs this image the same as the previous image? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n",
+            "<|im_start|>user\n<image>\nIs this image the same as the previous image? Answer yes or no.<|im_end|>\n<|im_start|>assistant\n",
+        ]
+        image_chunks = [
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+            "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png",
+        ]
+        assert len(text_chunks) == len(image_chunks) + 1
+
+        tokenizer = get_tokenizer(self.model)
+        text_input_ids = [tokenizer.encode(x) for x in text_chunks]
+
+        # 1. using session control
+        session_id = requests.post(
+            self.base_url + "/open_session",
+            json={"capacity_of_str_len": 1000},
+        ).json()
+
+        rid = None
+        first_rid = None
+        outputs_from_session = []
+        for i in range(len(text_input_ids)):
+            response = requests.post(
+                self.base_url + "/generate",
+                json={
+                    "input_ids": text_input_ids[i],
+                    "image_data": image_chunks[i - 1] if i > 0 else None,
+                    "modalities": ["multi-images"],
+                    "session": [session_id, rid],
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": (
+                            16 if i > 0 else 0
+                        ),  # prefill only for the first chunk
+                        "no_stop_trim": True,
+                        "skip_special_tokens": False,
+                    },
+                },
+            ).json()
+            rid = response["meta_info"]["id"]
+            if i == 0:
+                first_rid = rid
+            if i > 0:
+                outputs_from_session.append(response["text"])
+
+        # backtrack to the first request and regenerate
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "input_ids": text_input_ids[-1],
+                "image_data": image_chunks[-1:],
+                "modalities": ["multi-images"],
+                "session": [session_id, first_rid],
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 16,
+                    "no_stop_trim": True,
+                    "skip_special_tokens": False,
+                },
+            },
+        ).json()
+        outputs_from_session.append(response["text"])
+
+        # query with a non-existing rid (the last one should have disappeared
+        # because of the backtrack); should see abort
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "input_ids": text_input_ids[-1],
+                "image_data": image_chunks[-1:],
+                "modalities": ["multi-images"],
+                "session": [session_id, rid],
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 16,
+                    "no_stop_trim": True,
+                    "skip_special_tokens": False,
+                },
+            },
+        ).json()
+        assert response["meta_info"]["finish_reason"]["type"] == "abort"
+
+        ret = requests.post(
+            self.base_url + "/close_session",
+            json={"session_id": session_id},
+        )
+        assert ret.status_code == 200
+
+        # send a request to a closed session; should see abort
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "input_ids": text_input_ids[-1],
+                "session": [session_id, first_rid],
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 16,
+                    "no_stop_trim": True,
+                    "skip_special_tokens": False,
+                },
+            },
+        ).json()
+        assert response["meta_info"]["finish_reason"]["type"] == "abort"
+
+        # 2. not using session control
+        input_ids_first_req = None
+        input_ids = []
+        outputs_normal = []
+        for i in range(len(text_input_ids)):
+            input_ids += text_input_ids[i]
+            image_data = image_chunks[:i] if i > 0 else None
+            response = requests.post(
+                self.base_url + "/generate",
+                json={
+                    "input_ids": input_ids,
+                    "image_data": image_data,
+                    "modalities": ["multi-images"],
+                    "sampling_params": {
+                        "temperature": 0,
+                        "max_new_tokens": (
+                            16 if i > 0 else 0
+                        ),  # prefill only for the first chunk
+                        "no_stop_trim": True,
+                        "skip_special_tokens": False,
+                    },
+                },
+            ).json()
+            if i > 0:
+                output_ids = tokenizer.encode(response["text"])
+                if output_ids[0] == tokenizer.bos_token_id:
+                    output_ids = output_ids[1:]
+                input_ids += output_ids
+            outputs_normal.append(response["text"])
+            if i == 0:
+                input_ids_first_req = input_ids.copy()
+
+        input_ids_first_req += text_input_ids[-1]
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "input_ids": input_ids_first_req,
+                "image_data": image_chunks[-1:],
+                "modalities": ["multi-images"],
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": 16,
+                    "no_stop_trim": True,
+                    "skip_special_tokens": False,
                 },
             },
         ).json()
...