Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
662ecd93
Unverified
Commit
662ecd93
authored
Sep 09, 2024
by
Kaichen Zhang - NTU
Committed by
GitHub
Sep 09, 2024
Browse files
[Feat] Add modalities for vision server when handling pixel values for llava (#1346)
parent
8e6bdf85
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
40 additions
and
2 deletions
+40
-2
examples/runtime/llava_onevision/http_llava_onevision_test.py
...ples/runtime/llava_onevision/http_llava_onevision_test.py
+3
-0
python/sglang/srt/conversation.py
python/sglang/srt/conversation.py
+3
-0
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+4
-0
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+1
-0
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+5
-0
python/sglang/srt/managers/tp_worker.py
python/sglang/srt/managers/tp_worker.py
+2
-0
python/sglang/srt/model_executor/forward_batch_info.py
python/sglang/srt/model_executor/forward_batch_info.py
+2
-0
python/sglang/srt/models/llava.py
python/sglang/srt/models/llava.py
+9
-2
python/sglang/srt/openai_api/adapter.py
python/sglang/srt/openai_api/adapter.py
+7
-0
python/sglang/srt/openai_api/protocol.py
python/sglang/srt/openai_api/protocol.py
+1
-0
test/srt/test_vision_openai_server.py
test/srt/test_vision_openai_server.py
+3
-0
No files found.
examples/runtime/llava_onevision/http_llava_onevision_test.py
View file @
662ecd93
...
...
@@ -93,12 +93,14 @@ def multi_image_stream_request_test(client):
"image_url"
:
{
"url"
:
"https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
},
"modalities"
:
"multi-images"
,
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
},
"modalities"
:
"multi-images"
,
},
{
"type"
:
"text"
,
...
...
@@ -218,6 +220,7 @@ def prepare_video_messages(video_path):
frame_format
=
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
"data:image/jpeg;base64,{}"
},
"modalities"
:
"video"
,
}
for
base64_frame
in
base64_frames
:
...
...
python/sglang/srt/conversation.py
View file @
662ecd93
...
...
@@ -71,6 +71,7 @@ class Conversation:
# Stop criteria (the default one is EOS token)
stop_str
:
Union
[
str
,
List
[
str
]]
=
None
image_data
:
Optional
[
List
[
str
]]
=
None
modalities
:
Optional
[
List
[
str
]]
=
None
def
get_prompt
(
self
)
->
str
:
"""Get the prompt for generation."""
...
...
@@ -379,6 +380,7 @@ def generate_chat_conv(
sep2
=
conv
.
sep2
,
stop_str
=
conv
.
stop_str
,
image_data
=
[],
modalities
=
[],
)
if
isinstance
(
request
.
messages
,
str
):
...
...
@@ -408,6 +410,7 @@ def generate_chat_conv(
for
content
in
message
.
content
:
if
content
.
type
==
"image_url"
:
num_image_url
+=
1
conv
.
modalities
.
append
(
content
.
modalities
)
if
num_image_url
>
1
:
image_token
=
"<image>"
else
:
...
...
python/sglang/srt/managers/io_struct.py
View file @
662ecd93
...
...
@@ -50,6 +50,8 @@ class GenerateReqInput:
return_text_in_logprobs
:
bool
=
False
# Whether to stream output.
stream
:
bool
=
False
# The modalities of the image data [image, multi-images, video]
modalities
:
Optional
[
List
[
str
]]
=
None
def
post_init
(
self
):
if
(
self
.
text
is
None
and
self
.
input_ids
is
None
)
or
(
...
...
@@ -177,6 +179,8 @@ class TokenizedGenerateReqInput:
top_logprobs_num
:
int
# Whether to stream output
stream
:
bool
# Modalities of the input images
modalites
:
Optional
[
List
[
str
]]
=
None
@
dataclass
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
662ecd93
...
...
@@ -130,6 +130,7 @@ class Req:
self
.
image_sizes
=
None
self
.
image_offsets
=
None
self
.
pad_value
=
None
self
.
modalities
=
None
# Prefix info
self
.
extend_input_len
=
0
...
...
python/sglang/srt/managers/tokenizer_manager.py
View file @
662ecd93
...
...
@@ -188,6 +188,7 @@ class TokenizerManager:
pixel_values
,
image_hashes
,
image_sizes
=
await
self
.
_get_pixel_values
(
obj
.
image_data
if
not_use_index
else
obj
.
image_data
[
index
]
)
modalities
=
obj
.
modalities
return_logprob
=
(
obj
.
return_logprob
if
not_use_index
else
obj
.
return_logprob
[
index
]
)
...
...
@@ -243,6 +244,7 @@ class TokenizerManager:
pixel_values
,
image_hashes
,
image_sizes
=
await
self
.
_get_pixel_values
(
obj
.
image_data
[
0
]
)
modalities
=
obj
.
modalities
return_logprob
=
obj
.
return_logprob
[
0
]
logprob_start_len
=
obj
.
logprob_start_len
[
0
]
top_logprobs_num
=
obj
.
top_logprobs_num
[
0
]
...
...
@@ -263,6 +265,7 @@ class TokenizerManager:
logprob_start_len
,
top_logprobs_num
,
obj
.
stream
,
modalities
,
)
else
:
# is embedding
tokenized_obj
=
TokenizedEmbeddingReqInput
(
...
...
@@ -346,6 +349,7 @@ class TokenizerManager:
pixel_values
,
image_hashes
,
image_sizes
=
(
await
self
.
_get_pixel_values
(
obj
.
image_data
[
index
])
)
modalities
=
obj
.
modalities
tokenized_obj
=
TokenizedGenerateReqInput
(
rid
,
...
...
@@ -359,6 +363,7 @@ class TokenizerManager:
obj
.
logprob_start_len
[
index
],
obj
.
top_logprobs_num
[
index
],
obj
.
stream
,
modalities
,
)
else
:
tokenized_obj
=
TokenizedEmbeddingReqInput
(
...
...
python/sglang/srt/managers/tp_worker.py
View file @
662ecd93
...
...
@@ -358,6 +358,8 @@ class ModelTpServer:
req
.
pixel_values
,
req
.
image_sizes
,
)
# Modalities are only present when pixel values are not None
req
.
modalities
=
recv_req
.
modalites
req
.
return_logprob
=
recv_req
.
return_logprob
req
.
logprob_start_len
=
recv_req
.
logprob_start_len
req
.
top_logprobs_num
=
recv_req
.
top_logprobs_num
...
...
python/sglang/srt/model_executor/forward_batch_info.py
View file @
662ecd93
...
...
@@ -78,6 +78,7 @@ class InputMetadata:
pixel_values
:
List
[
torch
.
Tensor
]
=
None
image_sizes
:
List
[
List
[
List
[
int
]]]
=
None
image_offsets
:
List
[
List
[
int
]]
=
None
modalities
:
List
[
List
[
str
]]
=
None
# Triton attention backend
triton_max_seq_len
:
int
=
0
...
...
@@ -96,6 +97,7 @@ class InputMetadata:
self
.
pixel_values
=
[
r
.
pixel_values
for
r
in
reqs
]
self
.
image_sizes
=
[
r
.
image_sizes
for
r
in
reqs
]
self
.
image_offsets
=
[
r
.
image_offsets
for
r
in
reqs
]
self
.
modalities
=
[
r
.
modalities
for
r
in
reqs
]
def
compute_positions
(
self
,
batch
:
ScheduleBatch
):
position_ids_offsets
=
batch
.
position_ids_offsets
...
...
python/sglang/srt/models/llava.py
View file @
662ecd93
...
...
@@ -138,6 +138,12 @@ class LlavaBaseForCausalLM(nn.Module):
)
->
torch
.
Tensor
:
if
input_metadata
.
forward_mode
==
ForwardMode
.
EXTEND
:
bs
=
input_metadata
.
batch_size
# Got List[List[str]]; flatten it into List[str]
# The length of the List should be equal to batch size
modalities_list
=
[]
for
modalities
in
input_metadata
.
modalities
:
if
modalities
is
not
None
:
modalities_list
.
extend
(
modalities
)
# Embed text inputs
input_embeds
=
self
.
language_model
.
model
.
embed_tokens
(
input_ids
)
...
...
@@ -179,7 +185,7 @@ class LlavaBaseForCausalLM(nn.Module):
new_image_features
=
[]
height
=
width
=
self
.
num_patches_per_side
for
image_idx
,
image_feature
in
enumerate
(
image_features
):
if
len
(
image_sizes
[
image_idx
]
)
==
1
:
if
modalities_list
[
image_idx
]
==
1
:
image_aspect_ratio
=
(
self
.
config
.
image_aspect_ratio
)
# single image
...
...
@@ -191,6 +197,7 @@ class LlavaBaseForCausalLM(nn.Module):
if
(
image_feature
.
shape
[
0
]
>
1
and
"anyres"
in
image_aspect_ratio
and
modalities_list
[
image_idx
]
==
"image"
):
base_image_feature
=
image_feature
[
0
]
image_feature
=
image_feature
[
1
:]
...
...
@@ -290,7 +297,7 @@ class LlavaBaseForCausalLM(nn.Module):
)
image_feature
=
image_feature
.
unsqueeze
(
0
)
else
:
if
image_feature
.
shape
[
0
]
>
16
:
# video
if
modalities_list
[
image_idx
]
==
"video"
:
# video
# 2x2 pooling
num_of_frames
=
image_feature
.
shape
[
0
]
image_feature
=
image_feature
.
view
(
...
...
python/sglang/srt/openai_api/adapter.py
View file @
662ecd93
...
...
@@ -832,6 +832,7 @@ def v1_chat_generate_request(
return_logprobs
=
[]
logprob_start_lens
=
[]
top_logprobs_nums
=
[]
modalities_list
=
[]
# NOTE: with the OpenAI API, the prompt's logprobs are never computed
...
...
@@ -864,10 +865,12 @@ def v1_chat_generate_request(
)
stop
=
request
.
stop
image_data
=
None
modalities
=
[]
else
:
conv
=
generate_chat_conv
(
request
,
chat_template_name
)
prompt
=
conv
.
get_prompt
()
image_data
=
conv
.
image_data
modalities
=
conv
.
modalities
stop
=
conv
.
stop_str
or
[]
if
request
.
stop
:
if
isinstance
(
request
.
stop
,
str
):
...
...
@@ -880,6 +883,7 @@ def v1_chat_generate_request(
prompt_ids
=
request
.
messages
stop
=
request
.
stop
image_data
=
None
modalities
=
[]
input_ids
.
append
(
prompt_ids
)
return_logprobs
.
append
(
request
.
logprobs
)
logprob_start_lens
.
append
(
-
1
)
...
...
@@ -901,6 +905,7 @@ def v1_chat_generate_request(
}
)
image_data_list
.
append
(
image_data
)
modalities_list
.
extend
(
modalities
)
if
len
(
all_requests
)
==
1
:
input_ids
=
input_ids
[
0
]
if
isinstance
(
input_ids
,
str
):
...
...
@@ -912,6 +917,7 @@ def v1_chat_generate_request(
return_logprobs
=
return_logprobs
[
0
]
logprob_start_lens
=
logprob_start_lens
[
0
]
top_logprobs_nums
=
top_logprobs_nums
[
0
]
modalities_list
=
modalities_list
[:
1
]
else
:
if
isinstance
(
input_ids
[
0
],
str
):
prompt_kwargs
=
{
"text"
:
input_ids
}
...
...
@@ -928,6 +934,7 @@ def v1_chat_generate_request(
stream
=
all_requests
[
0
].
stream
,
return_text_in_logprobs
=
True
,
rid
=
request_ids
,
modalities
=
modalities_list
,
)
if
len
(
all_requests
)
==
1
:
return
adapted_request
,
all_requests
[
0
]
...
...
python/sglang/srt/openai_api/protocol.py
View file @
662ecd93
...
...
@@ -213,6 +213,7 @@ class ChatCompletionMessageContentImageURL(BaseModel):
class ChatCompletionMessageContentImagePart(BaseModel):
    """Image content part of a chat completion message.

    Carries the image reference plus a ``modalities`` tag naming how the
    image is meant to be treated.
    """

    # Discriminator: this part is always tagged "image_url".
    type: Literal["image_url"]
    # The image reference itself (see ChatCompletionMessageContentImageURL).
    image_url: ChatCompletionMessageContentImageURL
    # One of "image", "multi-images" or "video"; defaults to "image" when the
    # client omits the field. Optional, so an explicit null is also accepted.
    modalities: Optional[Literal["image", "multi-images", "video"]] = "image"
ChatCompletionMessageContentPart
=
Union
[
...
...
test/srt/test_vision_openai_server.py
View file @
662ecd93
...
...
@@ -140,12 +140,14 @@ class TestOpenAIVisionServer(unittest.TestCase):
"image_url"
:
{
"url"
:
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
},
"modalities"
:
"multi-images"
,
},
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
"https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
},
"modalities"
:
"multi-images"
,
},
{
"type"
:
"text"
,
...
...
@@ -192,6 +194,7 @@ class TestOpenAIVisionServer(unittest.TestCase):
frame_format
=
{
"type"
:
"image_url"
,
"image_url"
:
{
"url"
:
"data:image/jpeg;base64,{}"
},
"modalities"
:
"video"
,
}
for
base64_frame
in
base64_frames
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment