sglang / Commits / fd9ad817

Organize image inputs (#1531)

Unverified commit fd9ad817, authored Sep 28, 2024 by Liangsheng Yin; committed by GitHub on Sep 29, 2024.
Parent: e165a9fc
Showing 8 changed files with 121 additions and 132 deletions (+121 −132):

    python/sglang/srt/managers/io_struct.py                  +2  −8
    python/sglang/srt/managers/schedule_batch.py             +37 −14
    python/sglang/srt/managers/tokenizer_manager.py          +16 −21
    python/sglang/srt/managers/tp_worker.py                  +10 −22
    python/sglang/srt/model_executor/forward_batch_info.py   +3  −10
    python/sglang/srt/model_executor/model_runner.py         +1  −14
    python/sglang/srt/models/llava.py                        +30 −24
    python/sglang/srt/models/llavavid.py                     +22 −19
python/sglang/srt/managers/io_struct.py

```diff
@@ -172,12 +172,8 @@ class TokenizedGenerateReqInput:
     input_text: str
     # The input token ids
     input_ids: List[int]
-    # The pixel values for input images
-    pixel_values: List[float]
-    # The hash values of input images
-    image_hashes: List[int]
-    # The image sizes
-    image_sizes: List[List[int]]
+    # The image input
+    image_inputs: dict
     # The sampling parameters
     sampling_params: SamplingParams
     # Whether to return the logprobs
@@ -188,8 +184,6 @@ class TokenizedGenerateReqInput:
     top_logprobs_num: int
     # Whether to stream output
     stream: bool
-    # Modalities of the input images
-    modalites: Optional[List[str]] = None

     # LoRA related
     lora_path: Optional[str] = None  # None means just use the base model
```
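The practical effect of these two hunks: a text-only request now carries `image_inputs=None` instead of three separate `None` fields plus the misspelled `modalites` field. A minimal sketch of the new payload shape; the class name is hypothetical, but the dict keys match the ones `_get_image_inputs` builds in tokenizer_manager.py below:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class TokenizedGenerateReqInputSketch:
    input_text: str
    input_ids: List[int]
    # One dict (or None) instead of pixel_values / image_hashes /
    # image_sizes / modalites as four separate fields.
    image_inputs: Optional[dict] = None

req = TokenizedGenerateReqInputSketch(
    input_text="describe this image",
    input_ids=[1, 2, 3],
    image_inputs={
        "pixel_values": [[0.1, 0.2]],   # per-image tensors in the real code
        "image_hashes": [1234567890],   # one hash per input image
        "image_sizes": [[336, 336]],    # (height, width) per image
        "modalities": ["image"],        # e.g. "image"
    },
)
print(sorted(req.image_inputs))
```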
python/sglang/srt/managers/schedule_batch.py

```diff
@@ -102,6 +102,39 @@ class FINISH_ABORT(BaseFinishReason):
         }


+@dataclass
+class ImageInputs:
+    pixel_values: torch.Tensor
+    image_hash: int
+    image_sizes: Optional[list] = None
+    image_offsets: Optional[list] = None
+    pad_values: Optional[list] = None
+    modalities: Optional[list] = None
+
+    image_embeds: Optional[List[torch.Tensor]] = None
+    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
+
+    @staticmethod
+    def from_dict(obj, vocab_size):
+        # Use image hash as fake token_ids, which is then used for prefix matching
+        ret = ImageInputs(
+            pixel_values=obj["pixel_values"],
+            image_hash=hash(tuple(obj["image_hashes"])),
+        )
+        image_hash = ret.image_hash
+        ret.pad_values = [
+            (image_hash) % vocab_size,
+            (image_hash >> 16) % vocab_size,
+            (image_hash >> 32) % vocab_size,
+            (image_hash >> 64) % vocab_size,
+        ]
+        ret.image_sizes = obj["image_sizes"]
+        # Only when pixel values is not None we have modalities
+        ret.modalities = obj["modalities"]
+        return ret
+
+
 class Req:
     """Store all inforamtion of a request."""
@@ -147,11 +180,7 @@ class Req:
         self.completion_tokens_wo_jump_forward = 0

         # For vision inputs
-        self.pixel_values = None
-        self.image_sizes = None
-        self.image_offsets = None
-        self.pad_value = None
-        self.modalities = None
+        self.image_inputs: Optional[ImageInputs] = None

         # Prefix info
         self.prefix_indices = []
@@ -654,15 +683,9 @@ class ScheduleBatch:
                 self.tree_cache.cache_finished_req(req, cur_all_ids)

                 # re-applying image padding
-                if req.pixel_values is not None:
-                    (
-                        req.origin_input_ids,
-                        req.image_offsets,
-                    ) = model_runner.model.pad_input_ids(
-                        req.origin_input_ids_unpadded,
-                        req.pad_value,
-                        req.pixel_values,
-                        req.image_sizes,
+                if req.image_inputs is not None:
+                    req.origin_input_ids = model_runner.model.pad_input_ids(
+                        req.origin_input_ids_unpadded, req.image_inputs
                     )

                 jump_forward_reqs.append(req)
```
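The core trick preserved by `ImageInputs.from_dict` is the pad-value scheme: four pseudo-token ids are derived from a single hash over all image hashes, so identical image sets produce identical padding tokens and the prefix cache can match them across requests. A self-contained reproduction of just that arithmetic, with an illustrative `vocab_size`:

```python
# Standalone sketch of the pad-value derivation in ImageInputs.from_dict.
vocab_size = 32000  # illustrative; the real value comes from model_config

image_hashes = [hash(b"fake-image-bytes")]  # one hash per image
image_hash = hash(tuple(image_hashes))      # combined hash for the request

# Four shifted slices of the combined hash, each folded into the vocab.
# Python ints are arbitrary precision, so the >> 64 shift is well defined.
pad_values = [
    (image_hash) % vocab_size,
    (image_hash >> 16) % vocab_size,
    (image_hash >> 32) % vocab_size,
    (image_hash >> 64) % vocab_size,
]
print(pad_values)
```

Because these fake token ids are deterministic in the image content, two requests with the same images can share a cached prefix even though the real embeddings are injected later.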
python/sglang/srt/managers/tokenizer_manager.py

```diff
@@ -194,10 +194,9 @@ class TokenizerManager:
             )

         if self.is_generation:
-            pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
-                obj.image_data if not_use_index else obj.image_data[index]
+            image_inputs = await self._get_image_inputs(
+                obj, obj.image_data if not_use_index else obj.image_data[index]
             )
-            modalities = obj.modalities
             return_logprob = (
                 obj.return_logprob if not_use_index else obj.return_logprob[index]
             )
@@ -248,10 +247,7 @@ class TokenizerManager:
             sampling_params = SamplingParams(**obj.sampling_params[0])
             sampling_params.max_new_tokens = 0
-            pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
-                obj.image_data[0]
-            )
-            modalities = obj.modalities
+            image_inputs = await self._get_image_inputs(obj, obj.image_data[0])
             return_logprob = obj.return_logprob[0]
             logprob_start_len = obj.logprob_start_len[0]
             top_logprobs_num = obj.top_logprobs_num[0]
@@ -262,15 +258,12 @@ class TokenizerManager:
             rid,
             input_text,
             input_ids,
-            pixel_values,
-            image_hashes,
-            image_sizes,
+            image_inputs,
             sampling_params,
             return_logprob,
             logprob_start_len,
             top_logprobs_num,
             obj.stream,
-            modalities,
             (
                 obj.lora_path[index]
                 if isinstance(obj.lora_path, list)
@@ -369,24 +362,20 @@ class TokenizerManager:
             sampling_params = self._get_sampling_params(obj.sampling_params[index])

             if self.is_generation:
-                pixel_values, image_hashes, image_sizes = (
-                    await self._get_pixel_values(obj.image_data[index])
-                )
-                modalities = obj.modalities
+                image_inputs = await self._get_image_inputs(
+                    obj, obj.image_data[index]
+                )

                 tokenized_obj = TokenizedGenerateReqInput(
                     rid,
                     input_text,
                     input_ids,
-                    pixel_values,
-                    image_hashes,
-                    image_sizes,
+                    image_inputs,
                     sampling_params,
                     obj.return_logprob[index],
                     obj.logprob_start_len[index],
                     obj.top_logprobs_num[index],
                     obj.stream,
-                    modalities,
                     (
                         obj.lora_path[index]
                         if isinstance(obj.lora_path, list)
@@ -697,10 +686,11 @@ class TokenizerManager:
                 )
         return top_logprobs

-    async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
+    async def _get_image_inputs(self, obj, image_data: List[Union[str, bytes]]):
         if not image_data:
-            return None, None, None
+            return None

+        # TODO: move this into a processor for each vision architecture
         aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (
             self.hf_config.image_grid_pinpoints
@@ -741,7 +731,12 @@ class TokenizerManager:
         else:
             raise ValueError(f"Invalid image data: {image_data}")

-        return pixel_values, image_hashes, image_sizes
+        return {
+            "pixel_values": pixel_values,
+            "image_hashes": image_hashes,
+            "image_sizes": image_sizes,
+            "modalities": obj.modalities,
+        }

     async def _process_single_image(
         self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
```
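The helper's contract narrows from a 3-tuple (callers had to unpack `None, None, None` for pure-text requests) to a single dict or `None`. A rough sketch of that shape under hypothetical names; the real helper also decodes, resizes, and hashes the images:

```python
import asyncio
from typing import List, Optional, Union

async def get_image_inputs_sketch(
    modalities: Optional[List[str]],
    image_data: Optional[List[Union[str, bytes]]],
) -> Optional[dict]:
    if not image_data:
        return None  # text-only: one None instead of (None, None, None)
    # The real code decodes and preprocesses each image; we fake the outputs.
    return {
        "pixel_values": [[0.0]] * len(image_data),
        "image_hashes": [hash(bytes(d, "utf-8") if isinstance(d, str) else d)
                         for d in image_data],
        "image_sizes": [[336, 336]] * len(image_data),
        "modalities": modalities,
    }

print(asyncio.run(get_image_inputs_sketch(["image"], [b"raw-bytes"])) is not None)
print(asyncio.run(get_image_inputs_sketch(None, None)))
```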
python/sglang/srt/managers/tp_worker.py

```diff
@@ -49,6 +49,7 @@ from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
+    ImageInputs,
     Req,
     ScheduleBatch,
 )
@@ -340,29 +341,16 @@ class ModelTpServer:
             req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
             req.tokenizer = self.tokenizer
             req.sampling_params = recv_req.sampling_params
-            req.pixel_values = recv_req.pixel_values
-            if req.pixel_values is not None:
-                # Use image hash as fake token_ids, which is then used
-                # for prefix matching
-                image_hash = hash(tuple(recv_req.image_hashes))
-                req.pad_value = [
-                    (image_hash) % self.model_config.vocab_size,
-                    (image_hash >> 16) % self.model_config.vocab_size,
-                    (image_hash >> 32) % self.model_config.vocab_size,
-                    (image_hash >> 64) % self.model_config.vocab_size,
-                ]
-                req.image_sizes = recv_req.image_sizes
-                (
-                    req.origin_input_ids,
-                    req.image_offsets,
-                ) = self.model_runner.model.pad_input_ids(
-                    req.origin_input_ids_unpadded,
-                    req.pad_value,
-                    req.pixel_values,
-                    req.image_sizes,
-                )
-                # Only when pixel values is not None we have modalities
-                req.modalities = recv_req.modalites
+
+            # Image inputs
+            if recv_req.image_inputs is not None:
+                req.image_inputs = ImageInputs.from_dict(
+                    recv_req.image_inputs, self.model_config.vocab_size
+                )
+                req.origin_input_ids = self.model_runner.model.pad_input_ids(
+                    req.origin_input_ids_unpadded, req.image_inputs
+                )
+
             req.return_logprob = recv_req.return_logprob
             req.top_logprobs_num = recv_req.top_logprobs_num
             req.stream = recv_req.stream
```
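On the worker side, roughly twenty lines of inline hash and padding bookkeeping collapse into two calls: `ImageInputs.from_dict(...)` followed by `pad_input_ids(...)`. A self-contained sketch of the factory half; names suffixed `Sketch` are illustrative, not sglang APIs:

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ImageInputsSketch:
    image_hash: int
    pad_values: Optional[List[int]] = None

    @staticmethod
    def from_dict(obj: dict, vocab_size: int) -> "ImageInputsSketch":
        # Same scheme as the real from_dict: shifted slices of one hash.
        ret = ImageInputsSketch(image_hash=hash(tuple(obj["image_hashes"])))
        h = ret.image_hash
        ret.pad_values = [(h >> s) % vocab_size for s in (0, 16, 32, 64)]
        return ret

# What the worker now does with the dict shipped in the received request:
recv_image_inputs = {"image_hashes": [42]}
req_image_inputs = ImageInputsSketch.from_dict(recv_image_inputs, vocab_size=32000)
print(req_image_inputs.pad_values)
```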
python/sglang/srt/model_executor/forward_batch_info.py

```diff
@@ -25,7 +25,7 @@ import torch
 if TYPE_CHECKING:
     from sglang.srt.layers.attention_backend import AttentionBackend
-    from sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.managers.schedule_batch import ImageInputs, ScheduleBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -84,17 +84,10 @@ class InputMetadata:
     extend_logprob_start_lens_cpu: List[int] = None

     # For multimodal
-    pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[List[int]]] = None
-    image_offsets: List[List[int]] = None
-    modalities: List[List[str]] = None
+    image_inputs: List[ImageInputs] = None

     def init_multimuldal_info(self, batch: ScheduleBatch):
-        reqs = batch.reqs
-        self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.image_sizes for r in reqs]
-        self.image_offsets = [r.image_offsets for r in reqs]
-        self.modalities = [r.modalities for r in reqs]
+        self.image_inputs = [r.image_inputs for r in batch.reqs]

     def compute_positions(self, batch: ScheduleBatch):
         if self.forward_mode.is_decode():
```
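Batching follows the same consolidation: instead of zipping four parallel per-request lists, `InputMetadata` now gathers one `ImageInputs` (or `None` for text-only requests) per request. A trivial sketch of the gather:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ReqSketch:  # stand-in for sglang's Req
    image_inputs: Optional[object] = None

reqs = [ReqSketch(), ReqSketch(image_inputs="fake ImageInputs")]
# Mirrors init_multimuldal_info: one list comprehension instead of four.
image_inputs = [r.image_inputs for r in reqs]
print(image_inputs)  # [None, 'fake ImageInputs']
```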
python/sglang/srt/model_executor/model_runner.py

```diff
@@ -498,23 +498,10 @@ class ModelRunner:
             get_embedding=True,
         )

-    def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(self, batch)
-        return self.model.forward(
-            batch.input_ids,
-            input_metadata.positions,
-            input_metadata,
-            input_metadata.pixel_values,
-            input_metadata.image_sizes,
-            input_metadata.image_offsets,
-        )
-
     def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
         assert batch.forward_mode is not None

-        if self.is_multimodal_model and batch.forward_mode.is_extend():
-            return self.forward_extend_multi_modal(batch)
-        elif batch.forward_mode.is_decode():
+        if batch.forward_mode.is_decode():
             return self.forward_decode(batch)
         elif batch.forward_mode.is_extend():
             return self.forward_extend(batch)
```
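With the image tensors riding on `InputMetadata`, the dedicated `forward_extend_multi_modal` path can be deleted and every extend batch, multimodal or not, flows through `forward_extend`. A sketch of the simplified dispatch; the enum and function are illustrative stand-ins:

```python
from enum import Enum, auto

class ForwardModeSketch(Enum):
    DECODE = auto()
    EXTEND = auto()

def forward(mode: ForwardModeSketch) -> str:
    # No more `is_multimodal_model and is_extend()` special case: the model's
    # own forward pulls image inputs out of InputMetadata when they exist.
    if mode is ForwardModeSketch.DECODE:
        return "forward_decode"
    elif mode is ForwardModeSketch.EXTEND:
        return "forward_extend"
    raise ValueError(f"Invalid forward mode: {mode}")

print(forward(ForwardModeSketch.EXTEND))
```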
python/sglang/srt/models/llava.py

```diff
@@ -35,25 +35,22 @@ from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.mm_utils import (
     get_anyres_image_grid_shape,
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM


 class LlavaBaseForCausalLM(nn.Module):
-    def pad_input_ids(
-        self,
-        input_ids: List[int],
-        pad_value: List[int],
-        pixel_values: List,
-        image_sizes: List[List[int]],
-    ):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
+
         # hardcode for spatial_unpad + anyres
         image_aspect_ratio = "anyres" if len(image_sizes) == 1 else "pad"
         offset_list = []
@@ -92,8 +89,8 @@ class LlavaBaseForCausalLM(nn.Module):
                 new_w = int(new_w // times)
             new_image_feature_len += new_h * (new_w + 1)

-            pad_ids = pad_value * (
-                (new_image_feature_len + len(pad_value)) // len(pad_value)
+            pad_ids = pad_values * (
+                (new_image_feature_len + len(pad_values)) // len(pad_values)
             )
             # print("calculated new_image_feature_len: ", new_image_feature_len)
             try:
@@ -107,7 +104,9 @@ class LlavaBaseForCausalLM(nn.Module):
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)
-        return input_ids, offset_list
+
+        image_inputs.image_offsets = offset_list
+        return input_ids

     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -132,32 +131,39 @@ class LlavaBaseForCausalLM(nn.Module):
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        pixel_values: Optional[List[Optional[np.array]]] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
+        image_inputs = input_metadata.image_inputs
+
         if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
-            for modalities in input_metadata.modalities:
-                if modalities is not None:
-                    modalities_list.extend(modalities)
+            max_image_offset = []
+            for im in image_inputs:
+                if im and im.modalities is not None:
+                    modalities_list.extend(im.modalities)
+                if im and im.image_offsets is not None:
+                    max_image_offset.append(max(im.image_offsets))
+                else:
+                    max_image_offset.append(-1)

             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)

-            # Whether the requests need vision inputs
-            max_image_offset = np.array(
-                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
-            )
             start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
-            need_vision = start_positions <= max_image_offset
+            need_vision = start_positions <= np.array(max_image_offset)

             if need_vision.any():
-                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
-                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
+                pixel_values = [
+                    image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
+                ]
+                image_sizes = [
+                    image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
+                ]
+                image_offsets = [
+                    image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                ]

                 ########## Encode Image ########
```
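The `need_vision` test is unchanged in spirit but now reads offsets from the per-request `ImageInputs`: a request hits the vision tower only if its extend segment starts at or before its last image offset, with -1 marking image-free requests. A standalone reproduction:

```python
import numpy as np

image_offsets_per_req = [[5, 120], None, [3]]  # None = request without images
start_positions = np.array([0, 0, 50])         # first position of each extend chunk

max_image_offset = [max(offs) if offs else -1 for offs in image_offsets_per_req]
need_vision = start_positions <= np.array(max_image_offset)
print(need_vision)  # [ True False False]
```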
python/sglang/srt/models/llavavid.py

```diff
@@ -26,7 +26,8 @@ from vllm.config import CacheConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -54,17 +55,12 @@ class LlavaVidForCausalLM(nn.Module):
             torch.empty(config.text_config.hidden_size, dtype=torch.float16)
         )

-    def pad_input_ids(
-        self,
-        input_ids: List[int],
-        pad_value: List[int],
-        pixel_values: List,
-        image_sizes: List[List[int]],
-    ):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+        pad_values = image_inputs.pad_values
         new_image_feature_len = self.image_feature_len
-        pad_ids = pad_value * (
-            (new_image_feature_len + len(pad_value)) // len(pad_value)
+        pad_ids = pad_values * (
+            (new_image_feature_len + len(pad_values)) // len(pad_values)
         )
         offset = input_ids.index(self.config.image_token_index)
         # old_len + pad_len - 1, because we need to remove image_token_id
@@ -73,7 +69,8 @@ class LlavaVidForCausalLM(nn.Module):
             + pad_ids[: new_image_feature_len]
             + input_ids[offset + 1 :]
         )
-        return new_input_ids, [offset]
+        image_inputs.image_offsets = [offset]
+        return new_input_ids

     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -112,10 +109,8 @@ class LlavaVidForCausalLM(nn.Module):
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         input_metadata: InputMetadata,
-        pixel_values: Optional[List[Optional[np.array]]] = None,
-        image_sizes: Optional[List[List[int]]] = None,
-        image_offsets: Optional[List[int]] = None,
     ) -> torch.Tensor:
+        image_inputs = input_metadata.image_inputs
         if input_metadata.forward_mode.is_extend():
             bs = input_metadata.batch_size
@@ -123,14 +118,22 @@ class LlavaVidForCausalLM(nn.Module):
             input_embeds = self.language_model.model.embed_tokens(input_ids)

             # Whether the requests need vision inputs
-            max_image_offset = np.array(
-                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
-            )
+            max_image_offset = []
+            for im in image_inputs:
+                if im and im.image_offsets:
+                    max_image_offset.append(max(im.image_offsets))
+                else:
+                    max_image_offset.append(-1)

             start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
-            need_vision = start_positions <= max_image_offset
+            need_vision = start_positions <= np.array(max_image_offset)

             if need_vision.any():
-                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
+                pixel_values = [
+                    image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
+                ]
+                image_offsets = [
+                    image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
+                ]

                 ########## Encode Image ########
```
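Both LLaVA variants share the pad-token tiling that this refactor routes through `image_inputs.pad_values`: the four fake ids are repeated until they cover the image-feature span, then sliced to the exact length. A worked example of that arithmetic:

```python
pad_values = [11, 22, 33, 44]  # e.g. the four hash-derived pseudo tokens
new_image_feature_len = 10     # number of placeholder positions needed

# Over-tile, then slice: len(pad_ids) is always >= new_image_feature_len.
pad_ids = pad_values * ((new_image_feature_len + len(pad_values)) // len(pad_values))
assert len(pad_ids) >= new_image_feature_len
print(pad_ids[:new_image_feature_len])  # spliced into input_ids at the offset
```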