change / sglang / Commits / 5cb552b1

Commit 5cb552b1 (unverified) — refactor: multimodal data (#4754)
Authored Apr 01, 2025 by Mick; committed via GitHub Mar 31, 2025
Parent: c7457191

36 changed files in total; this page (1 of 2) shows 20 changed files with 564 additions and 662 deletions (+564 -662).
benchmark/mmmu/bench_hf.py                                            +32  -11
benchmark/mmmu/eval_utils.py                                           +2   -0
python/sglang/srt/managers/mm_utils.py                               +202 -156
python/sglang/srt/managers/multimodal_processor.py                     +0   -2
python/sglang/srt/managers/multimodal_processors/base_processor.py    +36  -72
python/sglang/srt/managers/multimodal_processors/clip.py               +7  -26
python/sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py    +17  -58
python/sglang/srt/managers/multimodal_processors/gemma3.py            +12  -27
python/sglang/srt/managers/multimodal_processors/janus_pro.py         +21  -47
python/sglang/srt/managers/multimodal_processors/llava.py             +23  -12
python/sglang/srt/managers/multimodal_processors/minicpm.py           +35  -38
python/sglang/srt/managers/multimodal_processors/mlama.py             +10  -23
python/sglang/srt/managers/multimodal_processors/qwen_vl.py           +21  -44
python/sglang/srt/managers/schedule_batch.py                         +116  -94
python/sglang/srt/managers/scheduler.py                                +1   -1
python/sglang/srt/managers/utils.py                                    +1   -6
python/sglang/srt/model_executor/forward_batch_info.py                 +0   -5
python/sglang/srt/model_executor/model_runner.py                       +6  -18
python/sglang/srt/models/clip.py                                      +12   -7
python/sglang/srt/models/deepseek_janus_pro.py                        +10  -15
benchmark/mmmu/bench_hf.py (+32 -11)

```diff
@@ -72,17 +72,38 @@ def eval_mmmu(args):
         if suffix:
             contents += [{"type": "text", "text": suffix}]
         messages = [{"role": "user", "content": contents}]
-        model_inputs = processor.apply_chat_template(
-            messages,
-            tokenize=True,
-            return_dict=True,
-            add_generation_prompt=True,
-            return_tensors="pt",
-        ).to(model.device)
-        input_len = model_inputs["input_ids"].shape[-1]
-        generation = model.generate(**model_inputs, generation_config=generation_config)
-        generation = generation[0][input_len:]
-        response = processor.decode(generation, skip_special_tokens=True)
+        try:
+            model_inputs = processor.tokenizer.apply_chat_template(
+                messages,
+                tokenize=True,
+                return_dict=True,
+                add_generation_prompt=True,
+                return_tensors="pt",
+            ).to(model.device)
+            input_len = model_inputs["input_ids"].shape[-1]
+            generation = model.generate(**model_inputs, generation_config=generation_config)
+            generation = generation[0][input_len:]
+            response = processor.decode(generation, skip_special_tokens=True)
+        except:
+            contents = []
+            if prefix:
+                contents += [prefix]
+            image = PIL.Image.open(sample["image_path"])
+            contents += [image]
+            if suffix:
+                contents += [suffix]
+            messages = [{"role": "user", "content": contents}]
+            response = model.chat(
+                msgs=messages,
+                tokenizer=processor.tokenizer,
+                sampling=False,
+                max_new_tokens=sampling_params["max_new_tokens"],
+                use_tts_template=False,
+                generate_audio=False,
+                temperature=0.0,
+            )
         print(f"response: {response}")
         process_result(response, sample, answer_dict, out_samples)
```
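The new path degrades gracefully: models whose processors support the generic HF chat-template route take it, and models that only expose a bespoke `chat()` API (the keyword set matches MiniCPM-o-style models) are handled in the `except` branch. Below is a minimal standalone sketch of this fallback pattern; `FakeProcessor` and `FakeModel` are hypothetical stand-ins, not sglang or transformers classes. Note the diff uses a bare `except:`, which also swallows `KeyboardInterrupt`; the sketch narrows it to `Exception`.

```python
# Sketch of the try-generic / fall-back-to-model-specific-chat pattern above.
class FakeProcessor:
    class tokenizer:
        @staticmethod
        def apply_chat_template(messages, **kwargs):
            # Stand-in for a processor whose tokenizer cannot build chat inputs.
            raise AttributeError("no chat template")

class FakeModel:
    def chat(self, msgs, **kwargs):
        # Stand-in for a model-specific chat API (MiniCPM-style).
        return "fallback response"

processor, model = FakeProcessor(), FakeModel()
messages = [{"role": "user", "content": ["describe the image"]}]
try:
    inputs = processor.tokenizer.apply_chat_template(messages, tokenize=True)
except Exception:  # safer than the diff's bare `except:`
    response = model.chat(msgs=messages)
print(response)
```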
benchmark/mmmu/eval_utils.py (+2 -0)

```diff
@@ -442,6 +442,8 @@ def calculate_ins_level_acc(results: Dict):
 def process_result(response, sample, answer_dict, out_samples):
+    if response is None:
+        return
     if sample["question_type"] == "multiple-choice":
         pred_ans = parse_multi_choice_response(
             response, sample["all_choices"], sample["index2ans"]
```
python/sglang/srt/managers/mm_utils.py (+202 -156) — this diff is collapsed on the page and not shown.
python/sglang/srt/managers/multimodal_processor.py (+0 -2)

```diff
@@ -64,5 +64,3 @@ def get_mm_processor(
         f"No processor registered for architecture: {hf_config.architectures}.\n"
         f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
     )
-
-    self.image_proce
```

(The final removed line is truncated in the page extraction and is left as shown.)
python/sglang/srt/managers/multimodal_processors/base_processor.py (+36 -72)

```diff
@@ -8,18 +8,10 @@ from typing import Optional
 import numpy as np
 import PIL
-import transformers
-from decord import VideoReader, cpu
 from PIL import Image

-from sglang.srt.utils import load_audio, load_image, logger
-
-global global_processor
-
-
-def get_global_processor():
-    global global_processor
-    return global_processor
+from sglang.srt.utils import encode_video, load_audio, load_image, logger


 @dataclasses.dataclass
@@ -27,9 +19,6 @@ class BaseMultiModalProcessorOutput:
     # input_text, with each frame of video/image represented with a image_token
     input_text: str

-    mm_data_hashes: Optional[list[int]]
-    # images
-    image_sizes: Optional[list[int]]
     # frames loaded from image and video, in given order
     images: Optional[list[PIL.Image]] = None
@@ -37,7 +26,7 @@ class BaseMultiModalProcessorOutput:
     audios: Optional[list[np.ndarray]] = None

     def normalize(self):
-        for field_name in ["data_hashes", "image_sizes", "images", "audios"]:
+        for field_name in ["image_sizes", "images", "audios"]:
             field = getattr(self, field_name, None)
             if field is not None and isinstance(field, list) and len(field) == 0:
                 setattr(self, field_name, None)
@@ -67,28 +56,35 @@ class BaseMultimodalProcessor(ABC):
         # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330

-        # Initialize global processor first
-        init_global_processor(self, server_args)
-
-        self.executor = concurrent.futures.ProcessPoolExecutor(
-            initializer=init_global_processor,
+        self.io_executor = concurrent.futures.ThreadPoolExecutor(
+            max_workers=int(os.environ.get("SGLANG_IO_WORKERS", 4))
+        )
+        self.cpu_executor = concurrent.futures.ProcessPoolExecutor(
             mp_context=mp.get_context("fork"),
-            initargs=(
-                self,
-                server_args,
-            ),
-            max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
+            max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
         )

-    def _build_processor(self, server_args):
-        """Init the global processor for multi modal models."""
-        from sglang.srt.hf_transformers_utils import get_processor
-
-        return get_processor(
-            server_args.tokenizer_path,
-            tokenizer_mode=server_args.tokenizer_mode,
-            trust_remote_code=server_args.trust_remote_code,
+    def process_mm_data(self, input_text, images=None, videos=None, audios=None, **kwargs):
+        """
+        process multimodal data with transformers AutoProcessor
+        """
+        if images is not None:
+            kwargs["images"] = images
+        if videos is not None:
+            kwargs["videos"] = videos
+        if audios is not None:
+            kwargs["audios"] = audios
+
+        processor = self._processor
+        result = processor.__call__(
+            text=[input_text],
+            padding=True,
+            return_tensors="pt",
+            **kwargs,
         )
+        return result

     @abstractmethod
     async def process_mm_data_async(
@@ -115,33 +111,9 @@ class BaseMultimodalProcessor(ABC):
         return estimated_frames_list

-    @staticmethod
-    def encode_video(video_path, frame_count_limit=None):
-        if not os.path.exists(video_path):
-            logger.error(f"Video {video_path} does not exist")
-            return []
-
-        if frame_count_limit == 0:
-            return []
-
-        def uniform_sample(l, n):
-            gap = len(l) / n
-            idxs = [int(i * gap + gap / 2) for i in range(n)]
-            return [l[i] for i in idxs]
-
-        vr = VideoReader(video_path, ctx=cpu(0))
-        sample_fps = round(vr.get_avg_fps() / 1)  # FPS
-        frame_indices = [i for i in range(0, len(vr), sample_fps)]
-        if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
-            frame_indices = uniform_sample(frame_indices, frame_count_limit)
-
-        frames = vr.get_batch(frame_indices).asnumpy()
-        frames = [Image.fromarray(v.astype("uint8")) for v in frames]
-        return frames
-
     def load_mm_data(
         self,
-        input_ids: list[int],
+        prompt: str,
         multimodal_tokens: MultimodalSpecialTokens,
         max_req_input_len: int,
         image_data: Optional[list] = None,
@@ -167,11 +139,13 @@ class BaseMultimodalProcessor(ABC):
         else:
             multimodal_tokens.image_token = multimodal_tokens.image_token

-        if isinstance(input_ids, list) and return_text:
-            assert len(input_ids) and isinstance(input_ids[0], int)
-            input_text = self._processor.tokenizer.decode(input_ids)
+        if isinstance(prompt, list) and return_text:
+            assert len(prompt) and isinstance(prompt[0], int)
+            prompt = self._processor.tokenizer.decode(prompt)
         else:
-            input_text = input_ids
+            prompt = prompt
+        assert isinstance(prompt, str)

         if return_text:
             import re
@@ -181,7 +155,7 @@ class BaseMultimodalProcessor(ABC):
                 + ")"
             )
             # split text into list of normal text and special tokens
-            text_parts = re.split(pattern, input_text)
+            text_parts = re.split(pattern, prompt)

         # TODO(mick): load from server_args, env, or sampling_params
         MAX_NUM_FRAMES = 30
@@ -217,7 +191,7 @@ class BaseMultimodalProcessor(ABC):
             ):
                 # video
                 path = image_file[len("video:"):]
-                frames = BaseMultimodalProcessor.encode_video(
+                frames = encode_video(
                     path, frame_count_limit=frames_to_process
                 )
             else:
@@ -254,19 +228,9 @@ class BaseMultimodalProcessor(ABC):
             raise RuntimeError(f"An exception occurred while loading images: {e}")

         out = BaseMultiModalProcessorOutput(
-            mm_data_hashes=hashes,
             image_sizes=image_sizes,
             images=images,
             audios=audios,
             input_text=new_text,
         )
         out.normalize()
         return out
-
-
-def init_global_processor(sglang_processor: BaseMultimodalProcessor, server_args):
-    """
-    Init the global processor for multimodal models."""
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = sglang_processor._build_processor(server_args=server_args)
```
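Taken together, the changes in this file delete the fork-based global-processor plumbing and route subclasses through the shared `process_mm_data` helper on the in-process HF processor. A minimal sketch of what a subclass's new call path looks like under that contract; `MyImageProcessor` and its `<image>` token are illustrative, not part of the commit:

```python
from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)


class MyImageProcessor(BaseMultimodalProcessor):  # hypothetical example subclass
    async def process_mm_data_async(
        self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
    ):
        # Resolve the prompt and load frames with the shared loader...
        base_output = self.load_mm_data(
            prompt=input_ids,
            image_data=image_data,
            multimodal_tokens=MultimodalSpecialTokens(image_token="<image>"),
            max_req_input_len=max_req_input_len,
        )
        # ...then run the HF AutoProcessor directly via the new helper,
        # instead of shipping the request to a forked worker process.
        ret = self.process_mm_data(
            input_text=base_output.input_text, images=base_output.images
        )
        return {"input_ids": ret["input_ids"].flatten().tolist()}
```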
python/sglang/srt/managers/multimodal_processors/clip.py (+7 -26)

```diff
-import asyncio
 from typing import List, Union

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.clip import CLIPModel
 from sglang.srt.utils import load_image
@@ -15,29 +14,6 @@ class ClipImageProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)

-    @staticmethod
-    def _process_single_image_task(images, input_text):
-        # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
-        return get_global_processor()(images=images, text=input_text, return_tensors="pt")
-
-    async def _process_single_image(self, images, input_text):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                ClipImageProcessor._process_single_image_task,
-                images,
-                input_text,
-            )
-        else:
-            image_inputs = self._processor(
-                images=images, text=[input_text], return_tensors="pt"
-            )
-
-        return image_inputs
-
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
@@ -56,8 +32,13 @@ class ClipImageProcessor(BaseMultimodalProcessor):
         else:
             images = load_image(image_data[0])[0]

-        image_inputs = await self._process_single_image(images, input_text)
+        image_inputs = self.process_mm_data(input_text=input_text, images=images)
         image_inputs["data_hashes"] = [hash(str(image_data))]
         image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+        image_inputs["mm_items"] = [
+            MultimodalDataItem(
+                pixel_values=image_inputs["pixel_values"], modality=Modality.IMAGE
+            )
+        ]

         return image_inputs
```
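The return-value convention visible here recurs in every processor below: raw tensors get wrapped in `MultimodalDataItem` entries under an `mm_items` key rather than spread across top-level dict keys. A small standalone sketch of that packaging with a dummy tensor; the `set_pad_value` call mirrors what `MultimodalInputs.from_dict` triggers later in `schedule_batch.py`:

```python
import torch

from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem

# Wrap a dummy image tensor the way the refactored processors do.
pixel_values = torch.zeros(1, 3, 224, 224)
item = MultimodalDataItem(pixel_values=pixel_values, modality=Modality.IMAGE)
item.set_pad_value()  # hashes the data and derives the radix-cache pad value
image_inputs = {"mm_items": [item]}
print(item.pad_value)
```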
python/sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py (+17 -58)

```diff
@@ -16,15 +16,14 @@
 # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-import asyncio
 import torch

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
@@ -35,51 +34,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.IMAGE_TOKEN = "<image>"

-    @staticmethod
-    def _process_images_task(image, input_text, max_req_input_len):
-        processor = get_global_processor()
-        res = processor.__call__(
-            conversations=input_text, images=image, max_req_input_len=max_req_input_len
-        )
-        image_token_id = processor.image_token_id
-        res["im_token_id"] = image_token_id
-        return res
-
-    async def _process_images(self, image_data, input_text, max_req_input_len):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                DeepseekVL2ImageProcessor._process_images_task,
-                image_data,
-                input_text,
-                max_req_input_len,
-            )
-        else:
-            image_inputs = self._process_images_task(
-                image_data, input_text, max_req_input_len
-            )
-
-        return image_inputs
-
     async def process_mm_data_async(
         self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
     ):
@@ -89,8 +43,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         if not isinstance(image_data, list):
             image_data = [image_data]

-        images, image_sizes = [], []
-
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
             input_ids,
@@ -98,8 +50,11 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
         )
-        res = await self._process_images(
-            base_output.images, base_output.input_text, max_req_input_len
+        res = self.process_mm_data(
+            input_text=base_output.input_text,
+            images=base_output.images,
+            max_req_input_len=max_req_input_len,
+            conversations=base_output.input_text,
         )
         images_seq_mask = res["images_seq_mask"]
         images_spatial_crop = res["images_spatial_crop"]
@@ -107,13 +62,17 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
             batched_images_spatial_crop.append(images_spatial_crop)
         batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)

+        items = []
+        item = MultimodalDataItem(
+            pixel_values=res["images"],
+            modality=Modality.IMAGE,
+            image_emb_mask=images_seq_mask,
+            image_spatial_crop=batched_images_spatial_crop,
+        )
+        items += [item]
+
         return {
+            "mm_items": items,
             "input_ids": res["input_ids"].tolist(),
-            "pixel_values": res["images"],
-            "im_token_id": res["im_token_id"],
-            "data_hashes": base_output.mm_data_hashes,
-            "image_sizes": image_sizes,
-            "images_emb_mask": images_seq_mask,
-            "image_spatial_crop": batched_images_spatial_crop,
-            "modalities": request_obj.modalities or ["image"],
+            "im_token_id": self._processor.image_token_id,
         }
```
python/sglang/srt/managers/multimodal_processors/gemma3.py (+12 -27)

```diff
@@ -7,8 +7,8 @@ from sglang.srt.managers.multimodal_processor import (
 )
 from sglang.srt.managers.multimodal_processors.base_processor import (
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration

 # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
@@ -25,28 +25,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         self.IM_START_TOKEN_ID = hf_config.boi_token_index
         self.IM_END_TOKEN_ID = hf_config.eoi_token_index

-    async def _process_single_image(self, images, input_text) -> dict:
-        if isinstance(images, list) and len(images) == 0:
-            images = None
-        processor = get_global_processor()
-        result = processor.__call__(
-            text=[input_text],
-            images=images,
-            padding=True,
-            return_tensors="pt",
-            # if RGBA, this needs to be set
-            # images_kwargs={
-            #     "input_data_format": ChannelDimension.FIRST
-            # }
-        )
-        pixel_values = getattr(result, "pixel_values", None)
-
-        return {
-            "input_ids": result.input_ids,
-            "pixel_values": pixel_values,
-        }
-
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
@@ -63,21 +41,28 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            input_ids=input_ids,
+            prompt=input_ids,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
             discard_alpha_channel=True,
         )

-        ret = await self._process_single_image(
+        ret = self.process_mm_data(
             input_text=base_output.input_text, images=base_output.images
         )
+
+        items = []
+        for i, image in enumerate(base_output.images):
+            item = MultimodalDataItem(
+                pixel_values=ret["pixel_values"][i],
+                modality=Modality.IMAGE,
+            )
+            items += [item]
+
         return {
+            "mm_items": items,
             "input_ids": ret["input_ids"].flatten().tolist(),
-            "pixel_values": ret["pixel_values"],
-            "data_hashes": base_output.mm_data_hashes,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
         }
```
python/sglang/srt/managers/multimodal_processors/janus_pro.py (+21 -47)

```diff
-import asyncio
 from typing import List, Union

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
@@ -15,37 +14,6 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)

-    @staticmethod
-    def _process_images_task(images, input_text):
-        processor = get_global_processor()
-        result = processor.__call__(prompt=input_text, images=images, return_tensors="pt")
-        return {
-            "input_ids": result["input_ids"],
-            "pixel_values": result["pixel_values"],
-            "images_emb_mask": result["images_emb_mask"],
-            "im_start_id": processor.image_start_id,
-            "im_end_id": processor.image_end_id,
-            "im_token_id": processor.image_id,
-        }
-
-    async def _process_images(self, images, input_text):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                JanusProImageProcessor._process_images_task,
-                images,
-                input_text,
-            )
-        else:
-            image_inputs = self._processor(
-                images=images, text=input_text, return_tensors="pt"
-            )
-
-        return image_inputs
-
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
@@ -60,25 +28,31 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         if not isinstance(image_data, list):
             image_data = [image_data]

+        processor = self._processor
+
         base_out = self.load_mm_data(
-            input_ids=input_ids,
+            prompt=input_ids,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token="<image_placeholder>"),
+            multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag),
             max_req_input_len=max_req_input_len,
         )
         images = base_out.images
-        res = await self._process_images(images=images, input_text=base_out.input_text)
-        # print(res)
-        # print(base_out)
-        # print("", res["images_emb_mask"].shape)
+        res = self.process_mm_data(
+            input_text=base_out.input_text,
+            prompt=base_out.input_text,
+            images=images,
+        )

         return {
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=res["pixel_values"],
+                    image_emb_mask=res["images_emb_mask"],
+                    modality=Modality.IMAGE,
+                )
+            ],
             "input_ids": res["input_ids"].flatten().tolist(),
-            "pixel_values": res["pixel_values"],
-            "images_emb_mask": res["images_emb_mask"],
-            "data_hashes": base_out.mm_data_hashes,
-            "im_start_id": res["im_start_id"],
-            "im_end_id": res["im_end_id"],
-            "im_token_id": res["im_token_id"],
+            "im_start_id": processor.image_start_id,
+            "im_end_id": processor.image_end_id,
+            "im_token_id": processor.image_id,
         }
```
python/sglang/srt/managers/multimodal_processors/llava.py (+23 -12)

```diff
@@ -5,8 +5,8 @@ import numpy as np
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM
 from sglang.srt.models.llavavid import LlavaVidForCausalLM
@@ -25,11 +25,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
         image_data: Union[str, bytes],
         image_aspect_ratio: Optional[str] = None,
         image_grid_pinpoints: Optional[str] = None,
-        image_processor=None,
+        processor=None,
     ):
-        processor = get_global_processor()
-        image_processor = image_processor or processor.image_processor
+        image_processor = processor.image_processor

         try:
             image, image_size = load_image(image_data)
@@ -72,18 +71,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
     async def _process_single_image(
         self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
     ):
-        if self.executor is not None:
+        if self.cpu_executor is not None:
             loop = asyncio.get_event_loop()
             return await loop.run_in_executor(
-                self.executor,
+                self.cpu_executor,
                 LlavaImageProcessor._process_single_image_task,
                 image_data,
                 aspect_ratio,
                 grid_pinpoints,
+                self._processor,
             )
         else:
             return self._process_single_image_task(
-                image_data, aspect_ratio, grid_pinpoints
+                image_data,
+                aspect_ratio,
+                grid_pinpoints,
+                self._processor.image_processor,
             )

     async def process_mm_data_async(
@@ -134,14 +137,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
             pixel_values, image_hash, image_size = await self._process_single_image(
                 image_data[0], aspect_ratio, grid_pinpoints
             )
             data_hashes = [image_hash]
             image_sizes = [image_size]
         else:
             raise ValueError(f"Invalid image data: {image_data}")

+        modality = Modality.IMAGE
+        if isinstance(request_obj.modalities, list):
+            if request_obj.modalities[0] == "multi-images":
+                modality = Modality.MULTI_IMAGES
+            elif request_obj.modalities[0] == "video":
+                modality = Modality.VIDEO
+
         return {
-            "pixel_values": pixel_values,
-            "data_hashes": data_hashes,
-            "image_sizes": image_sizes,
-            "modalities": request_obj.modalities or ["image"],
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=pixel_values,
+                    image_sizes=image_sizes,
+                    modality=modality,
+                )
+            ],
         }
```
python/sglang/srt/managers/multimodal_processors/minicpm.py (+35 -38)

```diff
-import asyncio
 from typing import List, Union

 import torch
+from transformers import BaseImageProcessorFast

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.minicpmo import MiniCPMO
 from sglang.srt.models.minicpmv import MiniCPMV
@@ -21,19 +21,23 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         self.image_token = "(<image>./</image>)"
         self.audio_token = "(<audio>./</audio>)"

-    @staticmethod
-    def _process_data_task(input_text, images=None, audios=None):
+    def process_data_task(self, input_text, images=None, audios=None):
         if isinstance(images, list) and len(images) == 0:
             images = None
         if isinstance(audios, list) and len(audios) == 0:
             audios = None
-        result = get_global_processor().__call__(
+        processor = self._processor
+        args = {}
+        if isinstance(processor, BaseImageProcessorFast):
+            args["device"] = "cuda"
+        result = self._processor.__call__(
             text=input_text,
             images=images,
             audios=audios,
             return_tensors="pt",
             chunk_input=True,
+            **args,
         )
         return {
             "input_ids": result.input_ids,
@@ -44,23 +48,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             "audio_bounds": getattr(result, "audio_bounds", None),
         }

-    async def _process_data(self, images, input_text, audios=None):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            multimodal_data_inputs = await loop.run_in_executor(
-                self.executor,
-                MiniCPMMultimodalProcessor._process_data_task,
-                input_text,
-                images,
-                audios,
-            )
-        else:
-            multimodal_data_inputs = self._processor(
-                images=images, text=input_text, audios=audios, return_tensors="pt"
-            )
-        return multimodal_data_inputs
-
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
@@ -77,7 +64,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             audio_data = [audio_data]

         base_output = self.load_mm_data(
-            input_ids=input_ids,
+            prompt=input_ids,
             max_req_input_len=max_req_input_len,
             audio_data=audio_data,
             image_data=image_data,
@@ -88,9 +75,9 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         if base_output is None:
             return None

-        res = await self._process_data(
-            images=base_output.images,
+        res = self.process_mm_data(
+            input_text=base_output.input_text,
+            images=base_output.images,
             audios=base_output.audios,
         )
@@ -142,23 +129,33 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
                 tgt_sizes_flat += [tgt_n]

         pixel_values = pixel_values_flat
         if len(tgt_sizes_flat) == 0:
             tgt_sizes = None
         else:
             tgt_sizes = torch.stack(tgt_sizes_flat)

+        if not isinstance(res["audio_features"], list):
+            res["audio_features"] = [res["audio_features"]]
+
+        items = []
+        if len(pixel_values) != 0:
+            item = MultimodalDataItem(
+                pixel_values=pixel_values,
+                tgt_size=tgt_sizes_flat,
+                modality=Modality.IMAGE,
+            )
+            items += [item]
+        if (
+            "audio_features" in res
+            and res["audio_features"] is not None
+            and len(res["audio_features"]) != 0
+        ):
+            item = MultimodalDataItem(
+                audio_features=[res["audio_features"]],
+                audio_feature_lens=res["audio_feature_lens"],
+                modality=Modality.AUDIO,
+            )
+            items += [item]
+
         return {
+            "mm_items": items,
             "input_ids": res["input_ids"].flatten().tolist(),
-            "pixel_values": pixel_values,
-            "tgt_sizes": tgt_sizes,
-            "data_hashes": base_output.mm_data_hashes,
-            "modalities": request_obj.modalities or ["image"],
             "audio_start_id": audio_start_id,
             "audio_end_id": audio_end_id,
-            "audio_features": res["audio_features"],
-            "audio_bounds": res["audio_bounds"],
-            "audio_feature_lens": res["audio_feature_lens"],
             "im_token_id": im_token_id,
             "im_start_id": tokenizer.im_start_id,
             "im_end_id": tokenizer.im_end_id,
```
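The minicpm hunk adds a conditional-kwargs pattern: extra processor arguments (here `device="cuda"`) are only passed when the processor is a fast, GPU-capable image processor. A standalone sketch of the same pattern; `FastProcessor` and `SlowProcessor` are hypothetical stand-ins for `BaseImageProcessorFast` and a slow processor:

```python
# Only add kwargs the processor type actually supports.
class SlowProcessor:
    def __call__(self, text, **kwargs):
        return {"text": text, **kwargs}

class FastProcessor(SlowProcessor):
    pass

def run(processor, text):
    args = {}
    if isinstance(processor, FastProcessor):
        # mirrors: if isinstance(processor, BaseImageProcessorFast): args["device"] = "cuda"
        args["device"] = "cuda"
    return processor(text=text, **args)

print(run(FastProcessor(), "hi"))  # includes device="cuda"
print(run(SlowProcessor(), "hi"))  # no device kwarg
```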
python/sglang/srt/managers/multimodal_processors/mlama.py (+10 -23)

```diff
-import asyncio
 from typing import List, Union

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.mllama import MllamaForConditionalGeneration
 from sglang.srt.utils import load_image
@@ -15,25 +14,6 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)

-    @staticmethod
-    def _process_single_image_task(images, input_text):
-        # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
-        return get_global_processor()(images, input_text, return_tensors="pt")
-
-    async def _process_single_image(self, images, input_text):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                MllamaImageProcessor._process_single_image_task,
-                images,
-                input_text,
-            )
-        else:
-            image_inputs = self._processor(images, input_text, return_tensors="pt")
-
-        return image_inputs
-
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
@@ -52,8 +32,15 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
         else:
             images = load_image(image_data[0])[0]

-        image_inputs = await self._process_single_image(images, input_text)
-        image_inputs["data_hashes"] = [hash(str(image_data))]
+        image_inputs = self.process_mm_data(input_text=input_text, images=images)
         image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+        image_inputs["mm_items"] = [
+            MultimodalDataItem(
+                pixel_values=image_inputs["pixel_values"],
+                aspect_ratio_id=image_inputs["aspect_ratio_ids"],
+                aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
+                modality=Modality.IMAGE,
+            )
+        ]

         return image_inputs
```
python/sglang/srt/managers/multimodal_processors/qwen_vl.py (+21 -44)

```diff
 import asyncio
 import math
 import time
 from typing import List, Union

 import torch
@@ -11,8 +10,8 @@ from sglang.srt.managers.multimodal_processor import (
 )
 from sglang.srt.managers.multimodal_processors.base_processor import (
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.qwen2_5_vl import Qwen2_5_VLForConditionalGeneration
 from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
@@ -34,45 +33,15 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         self.MAX_PIXELS = 16384 * 28 * 28
         self.MAX_RATIO = 200

-    @staticmethod
-    def _process_images_task(images, input_text, _hf_config):
-        if isinstance(images, list) and len(images) == 0:
-            images = None
-        result = get_global_processor().__call__(
-            text=[input_text], images=images, padding=True, return_tensors="pt"
-        )
-
-        return {
-            "input_ids": result.input_ids,
-            "pixel_values": getattr(result, "pixel_values", None),
-            "image_grid_thw": getattr(result, "image_grid_thw", None),
-            "second_per_grid_ts": getattr(result, "second_per_grid_ts", None),
-            "video_grid_thws": getattr(result, "video_grid_thws", None),
-        }
-
-    async def _process_single_image(self, images, input_text) -> dict:
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            return await loop.run_in_executor(
-                self.executor,
-                Qwen2_5VLImageProcessor._process_images_task,
-                images,
-                input_text,
-                self.hf_config,
-            )
-        else:
-            return self._process_images_task(images, input_text, self.hf_config)
-
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
         input_ids,
+        prompt,
         request_obj,
         max_req_input_len,
         *args,
         **kwargs,
     ):
+        start = time.time()
         if not image_data:
             return None
         if isinstance(image_data, str):
@@ -80,7 +49,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            input_ids=input_ids,
+            prompt=prompt,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
@@ -144,24 +113,32 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
             """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
             return math.floor(number / factor) * factor

-        images = [resize_image(image) for image in base_output.images]
+        async def resize_image_async(image):
+            return resize_image(image)

-        ret = await self._process_single_image(
-            images=images, input_text=base_output.input_text
+        resize_tasks = [resize_image_async(image) for image in base_output.images]
+        resized_images = await asyncio.gather(*resize_tasks)
+
+        ret = self.process_mm_data(
+            input_text=base_output.input_text,
+            images=resized_images,
         )

         image_grid_thws = torch.concat([ret["image_grid_thw"]])
         video_grid_thws = None
         return {
             "input_ids": ret["input_ids"].flatten().tolist(),
-            "pixel_values": ret["pixel_values"],
-            "data_hashes": base_output.mm_data_hashes,
-            "modalities": request_obj.modalities or ["image"],
-            "image_grid_thws": image_grid_thws,
-            "video_grid_thws": video_grid_thws,
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=ret["pixel_values"],
+                    image_grid_thws=image_grid_thws,
+                    # TODO
+                    video_grid_thws=None,
+                    second_per_grid_ts=ret.get("second_per_grid_ts", None),
+                    modality=Modality.IMAGE,
+                )
+            ],
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
             "im_token_id": self.image_token_id,
             "video_token_id": self.video_token_id,
-            "second_per_grid_ts": ret["second_per_grid_ts"],
         }
```
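A standalone sketch of the resize fan-out the qwen_vl hunk switches to: each synchronous resize is wrapped in a coroutine and awaited with `asyncio.gather`. Note that wrapping sync work this way preserves the async structure of the caller but does not by itself add parallelism; the resizes still run one after another on the event loop.

```python
import asyncio

def resize_image(image):
    # Stand-in for the real resize logic in the diff.
    return ("resized", image)

async def resize_all(images):
    async def resize_image_async(image):
        return resize_image(image)

    resize_tasks = [resize_image_async(image) for image in images]
    return await asyncio.gather(*resize_tasks)

print(asyncio.run(resize_all(["a.png", "b.png"])))
```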
python/sglang/srt/managers/schedule_batch.py (+116 -94)

```diff
 from __future__ import annotations

+from enum import Enum, auto
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -51,7 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import flatten_nested_list, get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -143,165 +145,185 @@ class FINISH_ABORT(BaseFinishReason):
 }


+class Modality(Enum):
+    IMAGE = auto()
+    MULTI_IMAGES = auto()
+    VIDEO = auto()
+    AUDIO = auto()
+
+
 @dataclasses.dataclass
-class MultimodalInputs:
-    """The image related inputs."""
+class MultimodalDataItem:
+    """
+    A single multimodal data, from a single image/video/audio or other
+    """
+
+    modality: Modality
+    hash: int = None
+    pad_value: int = None
+
+    aspect_ratio_id: Optional[List[torch.Tensor]] = None
+    aspect_ratio_mask: Optional[List[torch.Tensor]] = None

-    pixel_values: Union[torch.Tensor, np.array]
-    data_hashes: Optional[list] = None
-    image_sizes: Optional[list] = None
-    image_offsets: Optional[list] = None
+    image_sizes: Tuple[int, int] = None
+
+    # the real data, pixel_values or audio_features
+    # data: Union[List[torch.Tensor], List[np.array]]
+    pixel_values: Union[torch.Tensor, np.array] = None
+    image_grid_thws: Union[torch.Tensor, np.array] = None
+    video_grid_thws: Union[torch.Tensor, np.array] = None
+
+    image_emb_mask: Optional[torch.Tensor] = None
+    image_spatial_crop: Optional[torch.Tensor] = None
+    second_per_grid_ts: Optional[List[torch.Tensor]] = None
+
+    # [num_images, (n, w, h)]
+    tgt_size: Tuple[int, int] = None
+
+    audio_features: Union[torch.Tensor, np.array] = None
+    audio_feature_lens: Optional[List[torch.Tensor]] = None
+
+    @staticmethod
+    def is_empty_list(l):
+        if l is None:
+            return True
+        return len([item for item in flatten_nested_list(l) if item is not None]) == 0
+
+    def set_pad_value(self):
+        """
+        Set the pad value after first hashign the data
+        """
+
+        def hash_feature(f):
+            if isinstance(f, list):
+                return hash(tuple(flatten_nested_list(f)))
+            elif isinstance(f, np.ndarray):
+                arr = np.ascontiguousarray(f)
+                arr_bytes = arr.tobytes()
+                return hash(arr_bytes)
+            return hash(f)
+
+        if self.is_audio():
+            self.hash = hash_feature(self.audio_features)
+        else:
+            self.hash = hash_feature(self.pixel_values)
+
+        assert self.hash is not None
+        self.pad_value = self.hash % (1 << 30)
+
+    def is_audio(self):
+        return (self.modality == Modality.AUDIO) and not MultimodalDataItem.is_empty_list(
+            self.audio_features
+        )
+
+    def is_image(self):
+        return (
+            self.modality == Modality.IMAGE or self.modality == Modality.MULTI_IMAGES
+        ) and not MultimodalDataItem.is_empty_list(self.pixel_values)
+
+    def is_video(self):
+        return (self.modality == Modality.VIDEO) and not MultimodalDataItem.is_empty_list(
+            self.pixel_values
+        )
+
+    def validate(self):
+        ...
+        # TODO
+
+
+@dataclasses.dataclass
+class MultimodalInputs:
+    """The multimodal data related inputs."""
+
+    # items of data
+    mm_items: List[MultimodalDataItem]
     image_pad_len: Optional[list] = None
     pad_values: Optional[list] = None
     modalities: Optional[list] = None
     num_image_tokens: Optional[int] = None

-    # Llava related
-    aspect_ratio_ids: Optional[List[torch.Tensor]] = None
-    aspect_ratio_mask: Optional[List[torch.Tensor]] = None
-
-    # QWen2-VL related
-    # [num_of_images, t, h, w]
-    image_grid_thws: torch.Tensor = None
     mrope_position_delta: Optional[torch.Tensor] = None

-    # Qwen2-VL video related
-    video_token_id: Optional[int] = None
-    video_grid_thws: List[Tuple[int, int, int]] = None
-    second_per_grid_ts: Optional[List[torch.Tensor]] = None
-
-    # deepseek vl2 related
-    images_emb_mask: Optional[List[torch.Tensor]] = None
-    image_spatial_crop: Optional[List[torch.Tensor]] = None
-
-    # The id of the single-image placeholder token
+    # image
     im_token_id: Optional[torch.Tensor] = None
     # All the images in the batch should share the same special image
     # bound token ids.
     im_start_id: Optional[int] = None
     im_end_id: Optional[int] = None
     slice_start_id: Optional[int] = None
     slice_end_id: Optional[int] = None
-    # [num_images, 2 (w, h)]
-    tgt_sizes: Optional[list] = None

+    # video
+    video_token_id: Optional[int] = None
+
+    # audio
     audio_start_id: Optional[torch.Tensor] = None
     audio_end_id: Optional[torch.Tensor] = None
-    audio_features: Optional[List[torch.Tensor]] = None
-    audio_feature_lens: Optional[List[torch.Tensor]] = None

     @staticmethod
     def from_dict(obj: dict):
         ret = MultimodalInputs(
-            pixel_values=obj["pixel_values"],
-            data_hashes=obj["data_hashes"],
+            mm_items=obj["mm_items"],
         )
+
+        assert isinstance(ret.mm_items, list)
+        ret.mm_items = [
+            item
+            for item in ret.mm_items
+            if item.is_audio() or item.is_image() or item.is_video()
+        ]
+
+        assert len(ret.mm_items) != 0

         # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
         # Please note that if the `input_ids` is later used in the model forward,
         # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
         # errors in cuda kernels. See also llava.py for example.
-        ret.pad_values = [x % (1 << 30) for x in ret.data_hashes]
+        for item in ret.mm_items:
+            item.set_pad_value()

         optional_args = [
             "image_sizes",
             "modalities",
-            "aspect_ratio_ids",
-            "aspect_ratio_mask",
-            "image_grid_thws",
-            "images_emb_mask",
-            "image_spatial_crop",
             "im_token_id",
             "im_start_id",
             "im_end_id",
             "slice_start_id",
             "slice_end_id",
-            "tgt_sizes",
             "audio_start_id",
             "audio_end_id",
-            "audio_features",
-            "audio_feature_lens",
         ]
         for arg in optional_args:
             if arg in obj:
                 setattr(ret, arg, obj[arg])

-        # validate
-        assert (
-            isinstance(ret.pixel_values, torch.Tensor)
-            or isinstance(ret.pixel_values, np.ndarray)
-            or isinstance(ret.pixel_values, list)
-        )
-        assert ret.audio_features is None or isinstance(ret.audio_features, list)
-
         return ret

     def contains_image_inputs(self) -> bool:
-        """ """
-        return self.pixel_values is not None and self.pixel_values != []
+        return any(item.is_image() for item in self.mm_items)

     def contains_audio_inputs(self) -> bool:
-        """ """
-        return self.audio_features is not None and self.audio_features != []
+        return any(item.is_audio() for item in self.mm_items)
+
+    def collect_image_inputs(self) -> List[torch.Tensor]:
+        return [item.pixel_values for item in self.mm_items if item.is_image()]

     def merge(self, other: MultimodalInputs):
         """
         merge image inputs when requests are being merged
         """
-        if isinstance(self.pixel_values, list):
-            # in some rare cases, pixel values are list of patches with different shapes
-            # e.g. minicpm
-            self.pixel_values += other.pixel_values
-        else:
-            assert (
-                self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
-            ), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
-            self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
-
-        # args would be stacked along first dim
-        # usually these are already tensors
-        stack_args = [
-            # TODO: merge with image_grid_thws, basically the same thing
-            "tgt_sizes",
-            "image_spatial_crop",
-        ]
-        for arg in stack_args:
-            if getattr(self, arg, None) is None:
-                setattr(self, arg, getattr(other, arg, None))
-            elif getattr(other, arg, None) is not None:
-                # self and other both not None
-                setattr(
-                    self,
-                    arg,
-                    torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
-                )
-
-        if self.image_grid_thws is None:
-            self.image_grid_thws = other.image_grid_thws
-        elif other.image_grid_thws is not None:
-            self.image_grid_thws = torch.concat(
-                [self.image_grid_thws, other.image_grid_thws]
-            )
-
-        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
-        # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
-        # errors in cuda kernels. See also llava.py for example.
-        self.data_hashes += other.data_hashes
-        self.pad_values = [x % (1 << 30) for x in self.data_hashes]

         # args needed to be merged
         optional_args = [
-            "audio_features",
-            "image_sizes",
+            "items",
             "image_offsets",
             "image_pad_len",
             # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
-            "aspect_ratio_ids",
-            "aspect_ratio_mask",
-            "images_emb_mask",
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
```
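The pad-value scheme above is self-contained enough to demonstrate outside the class. A standalone sketch of what `set_pad_value` computes for an ndarray feature, following the diff's `hash_feature` branch for `np.ndarray`: hash the contiguous feature bytes and fold the result into [0, 2**30) so it can stand in as a fake token id for radix-cache prefix matching.

```python
import numpy as np

# Dummy image feature standing in for real pixel_values.
pixel_values = np.zeros((3, 224, 224), dtype=np.float32)

# Mirror of hash_feature's ndarray branch plus the final modulo.
arr_bytes = np.ascontiguousarray(pixel_values).tobytes()
pad_value = hash(arr_bytes) % (1 << 30)

assert 0 <= pad_value < (1 << 30)
print(pad_value)
```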
python/sglang/srt/managers/scheduler.py (+1 -1)

```diff
@@ -112,7 +112,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
```
python/sglang/srt/managers/utils.py (+1 -6)

```diff
-import json
 import logging
-import time
-from collections import defaultdict
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple
-
-import torch
+from typing import Optional

 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
```
python/sglang/srt/model_executor/forward_batch_info.py (+0 -5)

```diff
@@ -355,11 +355,6 @@ class ForwardBatch:
         for mm_input in valid_inputs[1:]:
             merged.merge(mm_input)

-        if isinstance(merged.pixel_values, np.ndarray):
-            merged.pixel_values = torch.from_numpy(merged.pixel_values)
-        if isinstance(merged.audio_features, np.ndarray):
-            merged.audio_features = torch.from_numpy(merged.audio_features)
-
         return merged

     def contains_image_inputs(self) -> bool:
```
python/sglang/srt/model_executor/model_runner.py (+6 -18)

```diff
@@ -251,17 +251,16 @@ class ModelRunner:
             self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

         if self.is_multimodal:
-            self.mem_fraction_static *= 0.95
+            self.mem_fraction_static *= 0.90
             logger.info(
                 f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                 f"because this is a multimodal model."
             )
-
-            if self.model_config.hf_config.architectures == [
-                "MllamaForConditionalGeneration"
-            ]:
-                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-                server_args.chunked_prefill_size = -1
+            logger.info(
+                "Automatically turn off --chunked-prefill-size for multimodal model."
+            )
+            server_args.chunked_prefill_size = -1

             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
@@ -269,18 +268,7 @@ class ModelRunner:
                 "Qwen2_5_VLForConditionalGeneration"
             ]:
                 # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
-                logger.info(
-                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
-                )
-                server_args.chunked_prefill_size = -1
-                server_args.disable_radix_cache = True
-
-            if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
-                # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
-                logger.info(
-                    "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
-                )
-                server_args.chunked_prefill_size = -1
+                logger.info("Automatically disable radix cache for qwen-vl series.")
                 server_args.disable_radix_cache = True

         if server_args.enable_deepep_moe:
```
python/sglang/srt/models/clip.py (+12 -7)

```diff
@@ -17,7 +17,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.model_executor.model_runner import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list


 class CLIPVisionEmbeddings(nn.Module):
@@ -368,7 +368,6 @@ class CLIPVisionTransformer(nn.Module):
         self,
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
-
         hidden_states = self.embeddings(pixel_values.to(self.device))
         hidden_states = self.pre_layrnorm(hidden_states)
@@ -456,12 +455,18 @@ class CLIPModel(nn.Module):
         get_embedding: bool = True,
     ):
         assert get_embedding, "CLIPEmbeddingModel is only used for embedding"

-        image_inputs = None
+        mm_inputs = []
         if forward_batch.mm_inputs is not None:
-            image_inputs = forward_batch.mm_inputs
-
-        if image_inputs is not None and image_inputs[0] is not None:
-            vision_outputs = self.vision_model(image_inputs[0].pixel_values)
+            mm_inputs = forward_batch.mm_inputs
+        pixel_values_list = [
+            item.pixel_values
+            for item in flatten_nested_list(
+                [mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
+            )
+        ]
+        if len(pixel_values_list) != 0:
+            pixel_values = torch.concat(pixel_values_list)
+            vision_outputs = self.vision_model(pixel_values)
             pooled_output = vision_outputs[:, 0, :]
             image_embeds = self.visual_projection(pooled_output)
             image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1)
```
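A standalone sketch of the batching idiom the new `CLIPModel.forward` uses: collect `pixel_values` across every `MultimodalDataItem` of every request in the batch, then concatenate into one tensor for a single vision-tower pass. Plain classes stand in for the real `mm_inputs`/`mm_items` structures, and this assumes `flatten_nested_list` flattens nested Python lists, as its use in the diff implies.

```python
import torch

from sglang.srt.utils import flatten_nested_list

class Item:  # hypothetical stand-in for MultimodalDataItem
    def __init__(self, pixel_values):
        self.pixel_values = pixel_values

class MMInput:  # hypothetical stand-in for MultimodalInputs
    def __init__(self, items):
        self.mm_items = items

# One request with an image, one request without multimodal data.
mm_inputs = [MMInput([Item(torch.zeros(1, 3, 224, 224))]), None]
pixel_values_list = [
    item.pixel_values
    for item in flatten_nested_list(
        [mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
    )
]
if len(pixel_values_list) != 0:
    pixel_values = torch.concat(pixel_values_list)
    print(pixel_values.shape)  # torch.Size([1, 3, 224, 224])
```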
python/sglang/srt/models/deepseek_janus_pro.py (+10 -15)

```diff
@@ -51,7 +51,7 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -1959,8 +1959,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         )
         self.logits_processor = LogitsProcessor(config)

-    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
-        pixel_values = image_input.pixel_values
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        pixel_values = torch.concat([item.pixel_values for item in items], dim=0)
         bs, n = pixel_values.shape[0:2]
         pixel_values = pixel_values.to(
             device=self.vision_model.device, dtype=self.vision_model.dtype
@@ -1976,7 +1976,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         return images_embeds

     def get_input_embeddings(self) -> nn.Embedding:
-        return self.language_model.model.embed_tokens
+        return self.language_model.get_input_embeddings()

     @torch.no_grad()
     def forward(
@@ -1984,23 +1984,18 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
         get_embedding: bool = False,
     ) -> torch.Tensor:
-        inputs_embeds = general_mm_embed_routine(
+        hidden_states = general_mm_embed_routine(
             input_ids=input_ids,
             forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            mm_data_embedding_func=self.get_image_feature,
-        )
-
-        return self.language_model(
-            input_ids=None,
+            image_data_embedding_func=self.get_image_feature,
+            language_model=self.language_model,
             positions=positions,
-            forward_batch=forward_batch,
-            input_embeds=inputs_embeds,
-            get_embedding=False,
         )

+        return hidden_states
+
     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
         return self.gen_aligner(self.gen_embed(image_ids))
```
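A standalone sketch of the new `get_image_feature` contract: instead of one `MultimodalInputs` carrying a single `pixel_values` tensor, the model receives a list of item-like objects and concatenates their tensors along the batch dimension before the vision tower runs. Dummy tensors stand in for real image batches.

```python
import torch

# Two items' pixel_values (e.g. 1 image and 2 images, each with 4 crops).
item_tensors = [torch.zeros(1, 4, 3, 384, 384), torch.zeros(2, 4, 3, 384, 384)]

# Mirror of: pixel_values = torch.concat([item.pixel_values for item in items], dim=0)
pixel_values = torch.concat(item_tensors, dim=0)
bs, n = pixel_values.shape[0:2]
print(bs, n)  # 3 4
```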