change / sglang · Commits · fd9ad817

Organize image inputs (#1531)

Unverified commit fd9ad817, authored Sep 28, 2024 by Liangsheng Yin, committed via GitHub on Sep 29, 2024.
Parent: e165a9fc

Changes: 8 files, with 121 additions and 132 deletions.
python/sglang/srt/managers/io_struct.py (+2, -8)
python/sglang/srt/managers/schedule_batch.py (+37, -14)
python/sglang/srt/managers/tokenizer_manager.py (+16, -21)
python/sglang/srt/managers/tp_worker.py (+10, -22)
python/sglang/srt/model_executor/forward_batch_info.py (+3, -10)
python/sglang/srt/model_executor/model_runner.py (+1, -14)
python/sglang/srt/models/llava.py (+30, -24)
python/sglang/srt/models/llavavid.py (+22, -19)
python/sglang/srt/managers/io_struct.py

@@ -172,12 +172,8 @@ class TokenizedGenerateReqInput:
     input_text: str
     # The input token ids
     input_ids: List[int]
-    # The pixel values for input images
-    pixel_values: List[float]
-    # The hash values of input images
-    image_hashes: List[int]
-    # The image sizes
-    image_sizes: List[List[int]]
+    # The image input
+    image_inputs: dict
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs
@@ -188,8 +184,6 @@ class TokenizedGenerateReqInput:
     top_logprobs_num: int
     # Whether to stream output
     stream: bool
-    # Modalities of the input images
-    modalites: Optional[List[str]] = None
     # LoRA related
     lora_path: Optional[str] = None  # None means just use the base model
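The five loose image fields above collapse into a single image_inputs dict. As a reference for the rest of the diff, here is a minimal sketch of that payload as it crosses the tokenizer-to-worker boundary; the key names come from this commit, while the values are illustrative stand-ins:

# Illustrative payload; the real dict is built by
# TokenizerManager._get_image_inputs (further down in this commit)
# and is None for text-only requests.
image_inputs = {
    "pixel_values": [[0.0] * 4],      # stand-in for the preprocessed tensors
    "image_hashes": [1234567890123],  # one hash per input image
    "image_sizes": [[336, 336]],      # (height, width) per image
    "modalities": ["image"],          # ["video"] for llavavid-style inputs
}
print(sorted(image_inputs))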
python/sglang/srt/managers/schedule_batch.py

@@ -102,6 +102,39 @@ class FINISH_ABORT(BaseFinishReason):
         }

+@dataclass
+class ImageInputs:
+    pixel_values: torch.Tensor
+    image_hash: int
+    image_sizes: Optional[list] = None
+    image_offsets: Optional[list] = None
+    pad_values: Optional[list] = None
+    modalities: Optional[list] = None
+
+    image_embeds: Optional[List[torch.Tensor]] = None
+    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
+    @staticmethod
+    def from_dict(obj, vocab_size):
+        # Use image hash as fake token_ids, which is then used for prefix matching
+        ret = ImageInputs(
+            pixel_values=obj["pixel_values"],
+            image_hash=hash(tuple(obj["image_hashes"])),
+        )
+        image_hash = ret.image_hash
+        ret.pad_values = [
+            (image_hash) % vocab_size,
+            (image_hash >> 16) % vocab_size,
+            (image_hash >> 32) % vocab_size,
+            (image_hash >> 64) % vocab_size,
+        ]
+        ret.image_sizes = obj["image_sizes"]
+        # Only when pixel values is not None we have modalities
+        ret.modalities = obj["modalities"]
+        return ret
+
 class Req:
     """Store all inforamtion of a request."""
@@ -147,11 +180,7 @@ class Req:
         self.completion_tokens_wo_jump_forward = 0

         # For vision inputs
-        self.pixel_values = None
-        self.image_sizes = None
-        self.image_offsets = None
-        self.pad_value = None
-        self.modalities = None
+        self.image_inputs: Optional[ImageInputs] = None

         # Prefix info
         self.prefix_indices = []
@@ -654,15 +683,9 @@ class ScheduleBatch:
                 self.tree_cache.cache_finished_req(req, cur_all_ids)

                 # re-applying image padding
-                if req.pixel_values is not None:
-                    (
-                        req.origin_input_ids,
-                        req.image_offsets,
-                    ) = model_runner.model.pad_input_ids(
-                        req.origin_input_ids_unpadded,
-                        req.pad_value,
-                        req.pixel_values,
-                        req.image_sizes,
-                    )
+                if req.image_inputs is not None:
+                    req.origin_input_ids = model_runner.model.pad_input_ids(
+                        req.origin_input_ids_unpadded, req.image_inputs
+                    )

                 jump_forward_reqs.append(req)
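The notable trick in from_dict is the fake-token encoding: the combined image hash is sliced into four in-vocabulary ids, so two requests containing identical images produce identical padded token sequences and can hit the radix prefix cache. A standalone sketch of the folding, with an illustrative vocab_size:

vocab_size = 32000  # illustrative; the real value comes from the model config
image_hash = hash(tuple([987654321]))  # hash over the per-image hashes

# Four pseudo token ids taken from different 16-bit slices of the hash.
# Note: Python's hash() fits in 64 bits, so the `>> 64` slice contributes
# little entropy; what matters is a collision across all four slices.
pad_values = [
    (image_hash) % vocab_size,
    (image_hash >> 16) % vocab_size,
    (image_hash >> 32) % vocab_size,
    (image_hash >> 64) % vocab_size,
]
print(pad_values)  # four in-vocab ids, cycled to fill the padded image region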
python/sglang/srt/managers/tokenizer_manager.py

@@ -194,10 +194,9 @@ class TokenizerManager:
             )

         if self.is_generation:
-            pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
-                obj.image_data if not_use_index else obj.image_data[index]
+            image_inputs = await self._get_image_inputs(
+                obj, obj.image_data if not_use_index else obj.image_data[index]
             )
-            modalities = obj.modalities
             return_logprob = (
                 obj.return_logprob if not_use_index else obj.return_logprob[index]
             )
@@ -248,10 +247,7 @@ class TokenizerManager:
             sampling_params = SamplingParams(**obj.sampling_params[0])
             sampling_params.max_new_tokens = 0
-            pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
-                obj.image_data[0]
-            )
-            modalities = obj.modalities
+            image_inputs = await self._get_image_inputs(obj, obj.image_data[0])
             return_logprob = obj.return_logprob[0]
             logprob_start_len = obj.logprob_start_len[0]
             top_logprobs_num = obj.top_logprobs_num[0]
@@ -262,15 +258,12 @@ class TokenizerManager:
             rid,
             input_text,
             input_ids,
-            pixel_values,
-            image_hashes,
-            image_sizes,
+            image_inputs,
             sampling_params,
             return_logprob,
             logprob_start_len,
             top_logprobs_num,
             obj.stream,
-            modalities,
             (
                 obj.lora_path[index]
                 if isinstance(obj.lora_path, list)
@@ -369,24 +362,20 @@ class TokenizerManager:
             sampling_params = self._get_sampling_params(obj.sampling_params[index])

             if self.is_generation:
-                pixel_values, image_hashes, image_sizes = (
-                    await self._get_pixel_values(obj.image_data[index])
+                image_inputs = await self._get_image_inputs(
+                    obj, obj.image_data[index]
                 )
-                modalities = obj.modalities

                 tokenized_obj = TokenizedGenerateReqInput(
                     rid,
                     input_text,
                     input_ids,
-                    pixel_values,
-                    image_hashes,
-                    image_sizes,
+                    image_inputs,
                     sampling_params,
                     obj.return_logprob[index],
                     obj.logprob_start_len[index],
                     obj.top_logprobs_num[index],
                     obj.stream,
-                    modalities,
                     (
                         obj.lora_path[index]
                         if isinstance(obj.lora_path, list)
@@ -697,10 +686,11 @@ class TokenizerManager:
         )
         return top_logprobs

-    async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
+    async def _get_image_inputs(self, obj, image_data: List[Union[str, bytes]]):
         if not image_data:
-            return None, None, None
+            return None

+        # TODO: move this into a processor for each vision architecture
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (
             self.hf_config.image_grid_pinpoints
@@ -741,7 +731,12 @@ class TokenizerManager:
         else:
             raise ValueError(f"Invalid image data: {image_data}")

-        return pixel_values, image_hashes, image_sizes
+        return {
+            "pixel_values": pixel_values,
+            "image_hashes": image_hashes,
+            "image_sizes": image_sizes,
+            "modalities": obj.modalities,
+        }

     async def _process_single_image(
         self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
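Note the changed return contract: _get_pixel_values returned a (None, None, None) tuple for text-only requests, whereas _get_image_inputs returns a single None or one dict. A sketch of the caller-side distinction; obj is a stub for an incoming generate request, and only the method shape follows the diff:

async def tokenize_one(tokenizer_manager, obj, index):
    # Sketch only: obj stands in for an incoming generate request.
    image_inputs = await tokenizer_manager._get_image_inputs(
        obj, obj.image_data[index]
    )
    if image_inputs is None:
        return None                # text-only: nothing image-related to attach
    return sorted(image_inputs)    # "image_hashes", "image_sizes", ...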
python/sglang/srt/managers/tp_worker.py

@@ -49,6 +49,7 @@ from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
+    ImageInputs,
     Req,
     ScheduleBatch,
 )
@@ -340,29 +341,16 @@ class ModelTpServer:
             req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
             req.tokenizer = self.tokenizer
             req.sampling_params = recv_req.sampling_params
-            req.pixel_values = recv_req.pixel_values
-            if req.pixel_values is not None:
-                # Use image hash as fake token_ids, which is then used
-                # for prefix matching
-                image_hash = hash(tuple(recv_req.image_hashes))
-                req.pad_value = [
-                    (image_hash) % self.model_config.vocab_size,
-                    (image_hash >> 16) % self.model_config.vocab_size,
-                    (image_hash >> 32) % self.model_config.vocab_size,
-                    (image_hash >> 64) % self.model_config.vocab_size,
-                ]
-                req.image_sizes = recv_req.image_sizes
-                (
-                    req.origin_input_ids,
-                    req.image_offsets,
-                ) = self.model_runner.model.pad_input_ids(
-                    req.origin_input_ids_unpadded,
-                    req.pad_value,
-                    req.pixel_values,
-                    req.image_sizes,
-                )
-                # Only when pixel values is not None we have modalities
-                req.modalities = recv_req.modalites
+
+            # Image inputs
+            if recv_req.image_inputs is not None:
+                req.image_inputs = ImageInputs.from_dict(
+                    recv_req.image_inputs, self.model_config.vocab_size
+                )
+                req.origin_input_ids = self.model_runner.model.pad_input_ids(
+                    req.origin_input_ids_unpadded, req.image_inputs
+                )
+
             req.return_logprob = recv_req.return_logprob
             req.top_logprobs_num = recv_req.top_logprobs_num
             req.stream = recv_req.stream
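pad_input_ids changes shape here as well: it used to return an (input_ids, image_offsets) tuple, and now returns only the padded ids while recording the offsets on the ImageInputs object. A runnable stub of the new call pattern; the Fake* classes are hypothetical, only the signatures mirror the diff:

from dataclasses import dataclass, field
from typing import List

@dataclass
class FakeImageInputs:            # stand-in for ImageInputs
    pad_values: List[int]
    image_offsets: List[int] = field(default_factory=list)

class FakeModel:                  # stand-in for a llava-style model
    def pad_input_ids(self, input_ids, image_inputs):
        offset = 0                # pretend the image token sits at position 0
        image_inputs.image_offsets = [offset]  # side effect, not a return value
        return image_inputs.pad_values + input_ids

inputs = FakeImageInputs(pad_values=[7, 8, 9, 10])
ids = FakeModel().pad_input_ids([1, 2, 3], inputs)
print(ids, inputs.image_offsets)  # offsets now live on the ImageInputs

Storing the offsets on ImageInputs keeps every derived image attribute in one object, so the scheduler's jump-forward re-padding (see schedule_batch.py above) no longer has to thread four arguments through.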
python/sglang/srt/model_executor/forward_batch_info.py

@@ -25,7 +25,7 @@ import torch
 if TYPE_CHECKING:
     from sglang.srt.layers.attention_backend import AttentionBackend
-    from sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -84,17 +84,10 @@ class InputMetadata:
     extend_logprob_start_lens_cpu: List[int] = None

     # For multimodal
-    pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[List[int]]] = None
-    image_offsets: List[List[int]] = None
-    modalities: List[List[str]] = None
+    image_inputs: List[ImageInputs] = None

     def init_multimuldal_info(self, batch: ScheduleBatch):
-        reqs = batch.reqs
-        self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_sizes for r in reqs]
-        self.image_offsets = [r.image_offsets for r in reqs]
-        self.modalities = [r.modalities for r in reqs]
+        self.image_inputs = [r.image_inputs for r in batch.reqs]

     def compute_positions(self, batch: ScheduleBatch):
         if self.forward_mode.is_decode():
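Four parallel per-request lists become one List[ImageInputs], with None entries marking text-only requests. A toy gather showing the shape the models now receive; the request class is a stand-in:

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeReq:                        # stand-in for Req
    image_inputs: Optional[str] = None

reqs = [FakeReq("img-0"), FakeReq(), FakeReq("img-2")]
# One gather instead of four; None marks text-only requests in the batch.
image_inputs = [r.image_inputs for r in reqs]
print(image_inputs)  # ['img-0', None, 'img-2']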
python/sglang/srt/model_executor/model_runner.py

@@ -498,23 +498,10 @@ class ModelRunner:
             get_embedding=True,
         )

-    def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(self, batch)
-        return self.model.forward(
-            batch.input_ids,
-            input_metadata.positions,
-            input_metadata,
-            input_metadata.pixel_values,
-            input_metadata.image_sizes,
-            input_metadata.image_offsets,
-        )
-
     def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
         assert batch.forward_mode is not None

-        if self.is_multimodal_model and batch.forward_mode.is_extend():
-            return self.forward_extend_multi_modal(batch)
-        elif batch.forward_mode.is_decode():
+        if batch.forward_mode.is_decode():
            return self.forward_decode(batch)
         elif batch.forward_mode.is_extend():
             return self.forward_extend(batch)
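With images riding on InputMetadata.image_inputs, multimodal extend batches no longer need a dedicated entry point and flow through the regular forward_extend path; the model pulls its image tensors from the metadata instead of positional arguments. A condensed sketch of the resulting dispatch, with all classes stubbed:

class FakeForwardMode:                         # stand-in for ForwardMode
    def __init__(self, mode): self.mode = mode
    def is_decode(self): return self.mode == "decode"
    def is_extend(self): return self.mode == "extend"

class FakeBatch:                               # stand-in for ScheduleBatch
    def __init__(self, mode): self.forward_mode = FakeForwardMode(mode)

class FakeModelRunner:                         # stand-in for ModelRunner
    def forward_decode(self, batch): return "decode path"
    def forward_extend(self, batch):
        # Multimodal batches take this same path now; the model reads
        # image tensors from input_metadata.image_inputs internally.
        return "extend path (text or multimodal)"
    def forward(self, batch):
        if batch.forward_mode.is_decode():
            return self.forward_decode(batch)
        elif batch.forward_mode.is_extend():
            return self.forward_extend(batch)

print(FakeModelRunner().forward(FakeBatch("extend")))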
python/sglang/srt/models/llava.py

@@ -35,25 +35,22 @@ from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM

 class LlavaBaseForCausalLM(nn.Module):
-    def pad_input_ids(
-        self,
-        input_ids: List[int],
-        pad_value: List[int],
-        pixel_values: List,
-        image_sizes: List[List[int]],
-    ):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
+
         # hardcode for spatial_unpad + anyres
         image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
         offset_list = []
@@ -92,8 +89,8 @@ class LlavaBaseForCausalLM(nn.Module):
                 new_w = int(new_w // times)
             new_image_feature_len += new_h * (new_w + 1)

-            pad_ids = pad_value * (
-                (new_image_feature_len + len(pad_value)) // len(pad_value)
+            pad_ids = pad_values * (
+                (new_image_feature_len + len(pad_values)) // len(pad_values)
             )
             # print("calculated new_image_feature_len: ", new_image_feature_len)
             try:
@@ -107,7 +104,9 @@ class LlavaBaseForCausalLM(nn.Module):
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)
-        return input_ids, offset_list
+
+        image_inputs.image_offsets = offset_list
+        return input_ids

     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -132,32 +131,39 @@ class LlavaBaseForCausalLM(nn.Module):
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        pixel_values: Optional[List[Optional[np.array]]] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
+        image_inputs = input_metadata.image_inputs
         if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
-            for modalities in input_metadata.modalities:
-                if modalities is not None:
-                    modalities_list.extend(modalities)
+            max_image_offset = []
+            for im in image_inputs:
+                if im and im.modalities is not None:
+                    modalities_list.extend(im.modalities)
+                if im and im.image_offsets is not None:
+                    max_image_offset.append(max(im.image_offsets))
+                else:
+                    max_image_offset.append(-1)

             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)

             # Whether the requests need vision inputs
-            max_image_offset = np.array(
-                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
-            )
             start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
-            need_vision = start_positions <= max_image_offset
+            need_vision = start_positions <= np.array(max_image_offset)

             if need_vision.any():
-                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
-                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
+                pixel_values = [
+                    image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
+                ]
+                image_sizes = [
+                    image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
+                ]
+                image_offsets = [
+                    image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                ]

                 ########## Encode Image ########
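The need_vision test is worth unpacking: a request only needs its images encoded if its extend (prefill) window starts at or before the last image token, i.e. if the image region was not already served from the radix cache. A small numpy reproduction of the predicate, with made-up offsets and positions:

import numpy as np

max_image_offset = np.array([5, -1, 40])   # -1: request has no images
start_positions = np.array([0, 0, 64])     # where each extend segment starts

need_vision = start_positions <= max_image_offset
print(need_vision)  # [ True False False]

Request 2's image tokens fall entirely before its extend window (a cache hit on the fake pad tokens), so its images skip the vision tower entirely.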
python/sglang/srt/models/llavavid.py

@@ -26,7 +26,8 @@ from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -54,17 +55,12 @@ class LlavaVidForCausalLM(nn.Module):
             torch.empty(config.text_config.hidden_size, dtype=torch.float16)
         )

-    def pad_input_ids(
-        self,
-        input_ids: List[int],
-        pad_value: List[int],
-        pixel_values: List,
-        image_sizes: List[List[int]],
-    ):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        pad_values = image_inputs.pad_values
         new_image_feature_len = self.image_feature_len

-        pad_ids = pad_value * (
-            (new_image_feature_len + len(pad_value)) // len(pad_value)
+        pad_ids = pad_values * (
+            (new_image_feature_len + len(pad_values)) // len(pad_values)
         )
         offset = input_ids.index(self.config.image_token_index)
         # old_len + pad_len - 1, because we need to remove image_token_id
@@ -73,7 +69,8 @@ class LlavaVidForCausalLM(nn.Module):
             + pad_ids[:new_image_feature_len]
             + input_ids[offset + 1 :]
         )
-        return new_input_ids, [offset]
+        image_inputs.image_offsets = [offset]
+        return new_input_ids

     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -112,10 +109,8 @@ class LlavaVidForCausalLM(nn.Module):
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        pixel_values: Optional[List[Optional[np.array]]] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
+        image_inputs = input_metadata.image_inputs
         if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size
@@ -123,14 +118,22 @@ class LlavaVidForCausalLM(nn.Module):
             input_embeds = self.language_model.model.embed_tokens(input_ids)

             # Whether the requests need vision inputs
-            max_image_offset = np.array(
-                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
-            )
+            max_image_offset = []
+            for im in image_inputs:
+                if im and im.image_offsets:
+                    max_image_offset.append(max(im.image_offsets))
+                else:
+                    max_image_offset.append(-1)
+
             start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
-            need_vision = start_positions <= max_image_offset
+            need_vision = start_positions <= np.array(max_image_offset)

             if need_vision.any():
-                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
+                pixel_values = [
+                    image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
+                ]
+                image_offsets = [
+                    image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                ]

                 ########## Encode Image ########