Unverified Commit 662ecd93 authored by Kaichen Zhang - NTU's avatar Kaichen Zhang - NTU Committed by GitHub
Browse files

[Feat] Add modalities for vision server when handling pixel values for llava (#1346)

parent 8e6bdf85
...@@ -93,12 +93,14 @@ def multi_image_stream_request_test(client): ...@@ -93,12 +93,14 @@ def multi_image_stream_request_test(client):
"image_url": { "image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
}, },
"modalities": "multi-images",
}, },
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png" "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
}, },
"modalities": "multi-images",
}, },
{ {
"type": "text", "type": "text",
...@@ -218,6 +220,7 @@ def prepare_video_messages(video_path): ...@@ -218,6 +220,7 @@ def prepare_video_messages(video_path):
frame_format = { frame_format = {
"type": "image_url", "type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,{}"}, "image_url": {"url": "data:image/jpeg;base64,{}"},
"modalities": "video",
} }
for base64_frame in base64_frames: for base64_frame in base64_frames:
......
...@@ -71,6 +71,7 @@ class Conversation: ...@@ -71,6 +71,7 @@ class Conversation:
# Stop criteria (the default one is EOS token) # Stop criteria (the default one is EOS token)
stop_str: Union[str, List[str]] = None stop_str: Union[str, List[str]] = None
image_data: Optional[List[str]] = None image_data: Optional[List[str]] = None
modalities: Optional[List[str]] = None
def get_prompt(self) -> str: def get_prompt(self) -> str:
"""Get the prompt for generation.""" """Get the prompt for generation."""
...@@ -379,6 +380,7 @@ def generate_chat_conv( ...@@ -379,6 +380,7 @@ def generate_chat_conv(
sep2=conv.sep2, sep2=conv.sep2,
stop_str=conv.stop_str, stop_str=conv.stop_str,
image_data=[], image_data=[],
modalities=[],
) )
if isinstance(request.messages, str): if isinstance(request.messages, str):
...@@ -408,6 +410,7 @@ def generate_chat_conv( ...@@ -408,6 +410,7 @@ def generate_chat_conv(
for content in message.content: for content in message.content:
if content.type == "image_url": if content.type == "image_url":
num_image_url += 1 num_image_url += 1
conv.modalities.append(content.modalities)
if num_image_url > 1: if num_image_url > 1:
image_token = "<image>" image_token = "<image>"
else: else:
......
...@@ -50,6 +50,8 @@ class GenerateReqInput: ...@@ -50,6 +50,8 @@ class GenerateReqInput:
return_text_in_logprobs: bool = False return_text_in_logprobs: bool = False
# Whether to stream output. # Whether to stream output.
stream: bool = False stream: bool = False
# The modalities of the image data [image, multi-images, video]
modalities: Optional[List[str]] = None
def post_init(self): def post_init(self):
if (self.text is None and self.input_ids is None) or ( if (self.text is None and self.input_ids is None) or (
...@@ -177,6 +179,8 @@ class TokenizedGenerateReqInput: ...@@ -177,6 +179,8 @@ class TokenizedGenerateReqInput:
top_logprobs_num: int top_logprobs_num: int
# Whether to stream output # Whether to stream output
stream: bool stream: bool
# Modalities of the input images
modalites: Optional[List[str]] = None
@dataclass @dataclass
......
...@@ -130,6 +130,7 @@ class Req: ...@@ -130,6 +130,7 @@ class Req:
self.image_sizes = None self.image_sizes = None
self.image_offsets = None self.image_offsets = None
self.pad_value = None self.pad_value = None
self.modalities = None
# Prefix info # Prefix info
self.extend_input_len = 0 self.extend_input_len = 0
......
...@@ -188,6 +188,7 @@ class TokenizerManager: ...@@ -188,6 +188,7 @@ class TokenizerManager:
pixel_values, image_hashes, image_sizes = await self._get_pixel_values( pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
obj.image_data if not_use_index else obj.image_data[index] obj.image_data if not_use_index else obj.image_data[index]
) )
modalities = obj.modalities
return_logprob = ( return_logprob = (
obj.return_logprob if not_use_index else obj.return_logprob[index] obj.return_logprob if not_use_index else obj.return_logprob[index]
) )
...@@ -243,6 +244,7 @@ class TokenizerManager: ...@@ -243,6 +244,7 @@ class TokenizerManager:
pixel_values, image_hashes, image_sizes = await self._get_pixel_values( pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
obj.image_data[0] obj.image_data[0]
) )
modalities = obj.modalities
return_logprob = obj.return_logprob[0] return_logprob = obj.return_logprob[0]
logprob_start_len = obj.logprob_start_len[0] logprob_start_len = obj.logprob_start_len[0]
top_logprobs_num = obj.top_logprobs_num[0] top_logprobs_num = obj.top_logprobs_num[0]
...@@ -263,6 +265,7 @@ class TokenizerManager: ...@@ -263,6 +265,7 @@ class TokenizerManager:
logprob_start_len, logprob_start_len,
top_logprobs_num, top_logprobs_num,
obj.stream, obj.stream,
modalities,
) )
else: # is embedding else: # is embedding
tokenized_obj = TokenizedEmbeddingReqInput( tokenized_obj = TokenizedEmbeddingReqInput(
...@@ -346,6 +349,7 @@ class TokenizerManager: ...@@ -346,6 +349,7 @@ class TokenizerManager:
pixel_values, image_hashes, image_sizes = ( pixel_values, image_hashes, image_sizes = (
await self._get_pixel_values(obj.image_data[index]) await self._get_pixel_values(obj.image_data[index])
) )
modalities = obj.modalities
tokenized_obj = TokenizedGenerateReqInput( tokenized_obj = TokenizedGenerateReqInput(
rid, rid,
...@@ -359,6 +363,7 @@ class TokenizerManager: ...@@ -359,6 +363,7 @@ class TokenizerManager:
obj.logprob_start_len[index], obj.logprob_start_len[index],
obj.top_logprobs_num[index], obj.top_logprobs_num[index],
obj.stream, obj.stream,
modalities,
) )
else: else:
tokenized_obj = TokenizedEmbeddingReqInput( tokenized_obj = TokenizedEmbeddingReqInput(
......
...@@ -358,6 +358,8 @@ class ModelTpServer: ...@@ -358,6 +358,8 @@ class ModelTpServer:
req.pixel_values, req.pixel_values,
req.image_sizes, req.image_sizes,
) )
# Modalities are only set when pixel values are present
req.modalities = recv_req.modalites
req.return_logprob = recv_req.return_logprob req.return_logprob = recv_req.return_logprob
req.logprob_start_len = recv_req.logprob_start_len req.logprob_start_len = recv_req.logprob_start_len
req.top_logprobs_num = recv_req.top_logprobs_num req.top_logprobs_num = recv_req.top_logprobs_num
......
...@@ -78,6 +78,7 @@ class InputMetadata: ...@@ -78,6 +78,7 @@ class InputMetadata:
pixel_values: List[torch.Tensor] = None pixel_values: List[torch.Tensor] = None
image_sizes: List[List[List[int]]] = None image_sizes: List[List[List[int]]] = None
image_offsets: List[List[int]] = None image_offsets: List[List[int]] = None
modalities: List[List[str]] = None
# Triton attention backend # Triton attention backend
triton_max_seq_len: int = 0 triton_max_seq_len: int = 0
...@@ -96,6 +97,7 @@ class InputMetadata: ...@@ -96,6 +97,7 @@ class InputMetadata:
self.pixel_values = [r.pixel_values for r in reqs] self.pixel_values = [r.pixel_values for r in reqs]
self.image_sizes = [r.image_sizes for r in reqs] self.image_sizes = [r.image_sizes for r in reqs]
self.image_offsets = [r.image_offsets for r in reqs] self.image_offsets = [r.image_offsets for r in reqs]
self.modalities = [r.modalities for r in reqs]
def compute_positions(self, batch: ScheduleBatch): def compute_positions(self, batch: ScheduleBatch):
position_ids_offsets = batch.position_ids_offsets position_ids_offsets = batch.position_ids_offsets
......
...@@ -138,6 +138,12 @@ class LlavaBaseForCausalLM(nn.Module): ...@@ -138,6 +138,12 @@ class LlavaBaseForCausalLM(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
if input_metadata.forward_mode == ForwardMode.EXTEND: if input_metadata.forward_mode == ForwardMode.EXTEND:
bs = input_metadata.batch_size bs = input_metadata.batch_size
# Got List[List[str]] extend it to List[str]
# The length of the List should be equal to batch size
modalities_list = []
for modalities in input_metadata.modalities:
if modalities is not None:
modalities_list.extend(modalities)
# Embed text inputs # Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids) input_embeds = self.language_model.model.embed_tokens(input_ids)
...@@ -179,7 +185,7 @@ class LlavaBaseForCausalLM(nn.Module): ...@@ -179,7 +185,7 @@ class LlavaBaseForCausalLM(nn.Module):
new_image_features = [] new_image_features = []
height = width = self.num_patches_per_side height = width = self.num_patches_per_side
for image_idx, image_feature in enumerate(image_features): for image_idx, image_feature in enumerate(image_features):
if len(image_sizes[image_idx]) == 1: if modalities_list[image_idx] == 1:
image_aspect_ratio = ( image_aspect_ratio = (
self.config.image_aspect_ratio self.config.image_aspect_ratio
) # single image ) # single image
...@@ -191,6 +197,7 @@ class LlavaBaseForCausalLM(nn.Module): ...@@ -191,6 +197,7 @@ class LlavaBaseForCausalLM(nn.Module):
if ( if (
image_feature.shape[0] > 1 image_feature.shape[0] > 1
and "anyres" in image_aspect_ratio and "anyres" in image_aspect_ratio
and modalities_list[image_idx] == "image"
): ):
base_image_feature = image_feature[0] base_image_feature = image_feature[0]
image_feature = image_feature[1:] image_feature = image_feature[1:]
...@@ -290,7 +297,7 @@ class LlavaBaseForCausalLM(nn.Module): ...@@ -290,7 +297,7 @@ class LlavaBaseForCausalLM(nn.Module):
) )
image_feature = image_feature.unsqueeze(0) image_feature = image_feature.unsqueeze(0)
else: else:
if image_feature.shape[0] > 16: # video if modalities_list[image_idx] == "video": # video
# 2x2 pooling # 2x2 pooling
num_of_frames = image_feature.shape[0] num_of_frames = image_feature.shape[0]
image_feature = image_feature.view( image_feature = image_feature.view(
......
...@@ -832,6 +832,7 @@ def v1_chat_generate_request( ...@@ -832,6 +832,7 @@ def v1_chat_generate_request(
return_logprobs = [] return_logprobs = []
logprob_start_lens = [] logprob_start_lens = []
top_logprobs_nums = [] top_logprobs_nums = []
modalities_list = []
# NOTE: with openai API, the prompt's logprobs are always not computed # NOTE: with openai API, the prompt's logprobs are always not computed
...@@ -864,10 +865,12 @@ def v1_chat_generate_request( ...@@ -864,10 +865,12 @@ def v1_chat_generate_request(
) )
stop = request.stop stop = request.stop
image_data = None image_data = None
modalities = []
else: else:
conv = generate_chat_conv(request, chat_template_name) conv = generate_chat_conv(request, chat_template_name)
prompt = conv.get_prompt() prompt = conv.get_prompt()
image_data = conv.image_data image_data = conv.image_data
modalities = conv.modalities
stop = conv.stop_str or [] stop = conv.stop_str or []
if request.stop: if request.stop:
if isinstance(request.stop, str): if isinstance(request.stop, str):
...@@ -880,6 +883,7 @@ def v1_chat_generate_request( ...@@ -880,6 +883,7 @@ def v1_chat_generate_request(
prompt_ids = request.messages prompt_ids = request.messages
stop = request.stop stop = request.stop
image_data = None image_data = None
modalities = []
input_ids.append(prompt_ids) input_ids.append(prompt_ids)
return_logprobs.append(request.logprobs) return_logprobs.append(request.logprobs)
logprob_start_lens.append(-1) logprob_start_lens.append(-1)
...@@ -901,6 +905,7 @@ def v1_chat_generate_request( ...@@ -901,6 +905,7 @@ def v1_chat_generate_request(
} }
) )
image_data_list.append(image_data) image_data_list.append(image_data)
modalities_list.extend(modalities)
if len(all_requests) == 1: if len(all_requests) == 1:
input_ids = input_ids[0] input_ids = input_ids[0]
if isinstance(input_ids, str): if isinstance(input_ids, str):
...@@ -912,6 +917,7 @@ def v1_chat_generate_request( ...@@ -912,6 +917,7 @@ def v1_chat_generate_request(
return_logprobs = return_logprobs[0] return_logprobs = return_logprobs[0]
logprob_start_lens = logprob_start_lens[0] logprob_start_lens = logprob_start_lens[0]
top_logprobs_nums = top_logprobs_nums[0] top_logprobs_nums = top_logprobs_nums[0]
modalities_list = modalities_list[:1]
else: else:
if isinstance(input_ids[0], str): if isinstance(input_ids[0], str):
prompt_kwargs = {"text": input_ids} prompt_kwargs = {"text": input_ids}
...@@ -928,6 +934,7 @@ def v1_chat_generate_request( ...@@ -928,6 +934,7 @@ def v1_chat_generate_request(
stream=all_requests[0].stream, stream=all_requests[0].stream,
return_text_in_logprobs=True, return_text_in_logprobs=True,
rid=request_ids, rid=request_ids,
modalities=modalities_list,
) )
if len(all_requests) == 1: if len(all_requests) == 1:
return adapted_request, all_requests[0] return adapted_request, all_requests[0]
......
...@@ -213,6 +213,7 @@ class ChatCompletionMessageContentImageURL(BaseModel): ...@@ -213,6 +213,7 @@ class ChatCompletionMessageContentImageURL(BaseModel):
class ChatCompletionMessageContentImagePart(BaseModel): class ChatCompletionMessageContentImagePart(BaseModel):
type: Literal["image_url"] type: Literal["image_url"]
image_url: ChatCompletionMessageContentImageURL image_url: ChatCompletionMessageContentImageURL
modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
ChatCompletionMessageContentPart = Union[ ChatCompletionMessageContentPart = Union[
......
...@@ -140,12 +140,14 @@ class TestOpenAIVisionServer(unittest.TestCase): ...@@ -140,12 +140,14 @@ class TestOpenAIVisionServer(unittest.TestCase):
"image_url": { "image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png" "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
}, },
"modalities": "multi-images",
}, },
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
}, },
"modalities": "multi-images",
}, },
{ {
"type": "text", "type": "text",
...@@ -192,6 +194,7 @@ class TestOpenAIVisionServer(unittest.TestCase): ...@@ -192,6 +194,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
frame_format = { frame_format = {
"type": "image_url", "type": "image_url",
"image_url": {"url": "data:image/jpeg;base64,{}"}, "image_url": {"url": "data:image/jpeg;base64,{}"},
"modalities": "video",
} }
for base64_frame in base64_frames: for base64_frame in base64_frames:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment