change / sglang — Commits

Commit 5cb552b1 (unverified)
refactor: multimodal data (#4754)

Authored Apr 01, 2025 by Mick; committed by GitHub on Mar 31, 2025.
Parent: c7457191

Changes: 36. Showing 20 changed files with 564 additions and 662 deletions (+564 −662).
Changed files:

- benchmark/mmmu/bench_hf.py (+32 −11)
- benchmark/mmmu/eval_utils.py (+2 −0)
- python/sglang/srt/managers/mm_utils.py (+202 −156)
- python/sglang/srt/managers/multimodal_processor.py (+0 −2)
- python/sglang/srt/managers/multimodal_processors/base_processor.py (+36 −72)
- python/sglang/srt/managers/multimodal_processors/clip.py (+7 −26)
- python/sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py (+17 −58)
- python/sglang/srt/managers/multimodal_processors/gemma3.py (+12 −27)
- python/sglang/srt/managers/multimodal_processors/janus_pro.py (+21 −47)
- python/sglang/srt/managers/multimodal_processors/llava.py (+23 −12)
- python/sglang/srt/managers/multimodal_processors/minicpm.py (+35 −38)
- python/sglang/srt/managers/multimodal_processors/mlama.py (+10 −23)
- python/sglang/srt/managers/multimodal_processors/qwen_vl.py (+21 −44)
- python/sglang/srt/managers/schedule_batch.py (+116 −94)
- python/sglang/srt/managers/scheduler.py (+1 −1)
- python/sglang/srt/managers/utils.py (+1 −6)
- python/sglang/srt/model_executor/forward_batch_info.py (+0 −5)
- python/sglang/srt/model_executor/model_runner.py (+6 −18)
- python/sglang/srt/models/clip.py (+12 −7)
- python/sglang/srt/models/deepseek_janus_pro.py (+10 −15)
benchmark/mmmu/bench_hf.py

```diff
@@ -72,7 +72,8 @@ def eval_mmmu(args):
         if suffix:
             contents += [{"type": "text", "text": suffix}]
         messages = [{"role": "user", "content": contents}]
-        model_inputs = processor.apply_chat_template(
-            messages,
-            tokenize=True,
-            return_dict=True,
+        try:
+            model_inputs = processor.tokenizer.apply_chat_template(
+                messages,
+                tokenize=True,
+                return_dict=True,
@@ -80,9 +81,29 @@ def eval_mmmu(args):
-            return_tensors="pt",
-        ).to(model.device)
-        input_len = model_inputs["input_ids"].shape[-1]
-        generation = model.generate(**model_inputs, generation_config=generation_config)
-        generation = generation[0][input_len:]
-        response = processor.decode(generation, skip_special_tokens=True)
+                return_tensors="pt",
+            ).to(model.device)
+            input_len = model_inputs["input_ids"].shape[-1]
+            generation = model.generate(
+                **model_inputs, generation_config=generation_config
+            )
+            generation = generation[0][input_len:]
+            response = processor.decode(generation, skip_special_tokens=True)
+        except:
+            contents = []
+            if prefix:
+                contents += [prefix]
+            image = PIL.Image.open(sample["image_path"])
+            contents += [image]
+            if suffix:
+                contents += [suffix]
+            messages = [{"role": "user", "content": contents}]
+            response = model.chat(
+                msgs=messages,
+                tokenizer=processor.tokenizer,
+                sampling=False,
+                max_new_tokens=sampling_params["max_new_tokens"],
+                use_tts_template=False,
+                generate_audio=False,
+                temperature=0.0,
+            )
         print(f"response: {response}")
         process_result(response, sample, answer_dict, out_samples)
```
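The new try/except reflects a common situation when benchmarking heterogeneous HF checkpoints: prefer the tokenizer's chat template, and fall back to a model-specific `.chat()` API (as on MiniCPM-style models) when the template path raises. A minimal standalone sketch of the same control flow; the `generate_response` helper is a hypothetical stand-in, not part of the benchmark:

```python
# Sketch of the fallback pattern above; the helper name is hypothetical.
def generate_response(model, processor, messages):
    try:
        # Preferred path: the tokenizer ships a chat template.
        inputs = processor.tokenizer.apply_chat_template(
            messages, tokenize=True, return_dict=True, return_tensors="pt"
        ).to(model.device)
        input_len = inputs["input_ids"].shape[-1]
        output = model.generate(**inputs)
        return processor.decode(output[0][input_len:], skip_special_tokens=True)
    except Exception:
        # Fallback: some checkpoints expose a native .chat() taking raw PIL
        # images and strings instead of templated token ids.
        return model.chat(msgs=messages, tokenizer=processor.tokenizer, sampling=False)
```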
benchmark/mmmu/eval_utils.py

```diff
@@ -442,6 +442,8 @@ def calculate_ins_level_acc(results: Dict):
 def process_result(response, sample, answer_dict, out_samples):
+    if response is None:
+        return
     if sample["question_type"] == "multiple-choice":
         pred_ans = parse_multi_choice_response(
             response, sample["all_choices"], sample["index2ans"]
```
python/sglang/srt/managers/mm_utils.py

This diff is collapsed (+202 −156).
python/sglang/srt/managers/multimodal_processor.py

```diff
@@ -64,5 +64,3 @@ def get_mm_processor(
         f"No processor registered for architecture: {hf_config.architectures}.\n"
         f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
     )
-
-    self.image_proce
```
python/sglang/srt/managers/multimodal_processors/base_processor.py

```diff
@@ -8,18 +8,10 @@ from typing import Optional
 import numpy as np
 import PIL
-import transformers
 from decord import VideoReader, cpu
 from PIL import Image

-from sglang.srt.utils import load_audio, load_image, logger
-
-global global_processor
-
-
-def get_global_processor():
-    global global_processor
-    return global_processor
+from sglang.srt.utils import encode_video, load_audio, load_image, logger


 @dataclasses.dataclass
@@ -27,9 +19,6 @@ class BaseMultiModalProcessorOutput:
     # input_text, with each frame of video/image represented with a image_token
     input_text: str

-    mm_data_hashes: Optional[list[int]]
-    # images
-    image_sizes: Optional[list[int]]
     # frames loaded from image and video, in given order
     images: Optional[list[PIL.Image]] = None
@@ -37,7 +26,7 @@ class BaseMultiModalProcessorOutput:
     audios: Optional[list[np.ndarray]] = None

     def normalize(self):
-        for field_name in ["data_hashes", "image_sizes", "images", "audios"]:
+        for field_name in ["image_sizes", "images", "audios"]:
             field = getattr(self, field_name, None)
             if field is not None and isinstance(field, list) and len(field) == 0:
                 setattr(self, field_name, None)
@@ -67,28 +56,35 @@ class BaseMultimodalProcessor(ABC):
         # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330

-        # Initialize global processor first
-        init_global_processor(self, server_args)
-
-        self.executor = concurrent.futures.ProcessPoolExecutor(
-            initializer=init_global_processor,
-            mp_context=mp.get_context("fork"),
-            initargs=(
-                self,
-                server_args,
-            ),
-            max_workers=int(os.environ.get("SGLANG_CPU_COUNT", os.cpu_count())),
-        )
-
-    def _build_processor(self, server_args):
-        """Init the global processor for multi modal models."""
-        from sglang.srt.hf_transformers_utils import get_processor
-
-        return get_processor(
-            server_args.tokenizer_path,
-            tokenizer_mode=server_args.tokenizer_mode,
-            trust_remote_code=server_args.trust_remote_code,
-        )
+        self.io_executor = concurrent.futures.ThreadPoolExecutor(
+            max_workers=int(os.environ.get("SGLANG_IO_WORKERS", 4))
+        )
+        self.cpu_executor = concurrent.futures.ProcessPoolExecutor(
+            mp_context=mp.get_context("fork"),
+            max_workers=int(os.environ.get("SGLANG_CPU_WORKERS", os.cpu_count())),
+        )
+
+    def process_mm_data(
+        self, input_text, images=None, videos=None, audios=None, **kwargs
+    ):
+        """
+        process multimodal data with transformers AutoProcessor
+        """
+        if images is not None:
+            kwargs["images"] = images
+        if videos is not None:
+            kwargs["videos"] = videos
+        if audios is not None:
+            kwargs["audios"] = audios
+
+        processor = self._processor
+        result = processor.__call__(
+            text=[input_text],
+            padding=True,
+            return_tensors="pt",
+            **kwargs,
+        )
+        return result

     @abstractmethod
     async def process_mm_data_async(
@@ -115,33 +111,9 @@ class BaseMultimodalProcessor(ABC):
         return estimated_frames_list

-    @staticmethod
-    def encode_video(video_path, frame_count_limit=None):
-        if not os.path.exists(video_path):
-            logger.error(f"Video {video_path} does not exist")
-            return []
-
-        if frame_count_limit == 0:
-            return []
-
-        def uniform_sample(l, n):
-            gap = len(l) / n
-            idxs = [int(i * gap + gap / 2) for i in range(n)]
-            return [l[i] for i in idxs]
-
-        vr = VideoReader(video_path, ctx=cpu(0))
-        sample_fps = round(vr.get_avg_fps() / 1)  # FPS
-        frame_indices = [i for i in range(0, len(vr), sample_fps)]
-        if frame_count_limit is not None and len(frame_indices) > frame_count_limit:
-            frame_indices = uniform_sample(frame_indices, frame_count_limit)
-
-        frames = vr.get_batch(frame_indices).asnumpy()
-        frames = [Image.fromarray(v.astype("uint8")) for v in frames]
-        return frames
-
     def load_mm_data(
         self,
-        input_ids: list[int],
+        prompt: str,
         multimodal_tokens: MultimodalSpecialTokens,
         max_req_input_len: int,
         image_data: Optional[list] = None,
@@ -167,11 +139,13 @@ class BaseMultimodalProcessor(ABC):
         else:
             multimodal_tokens.image_token = multimodal_tokens.image_token

-        if isinstance(input_ids, list) and return_text:
-            assert len(input_ids) and isinstance(input_ids[0], int)
-            input_text = self._processor.tokenizer.decode(input_ids)
+        assert isinstance(prompt, str)
+        if isinstance(prompt, list) and return_text:
+            assert len(prompt) and isinstance(prompt[0], int)
+            prompt = self._processor.tokenizer.decode(prompt)
         else:
-            input_text = input_ids
+            prompt = prompt
         if return_text:
             import re
@@ -181,7 +155,7 @@ class BaseMultimodalProcessor(ABC):
             + ")"
         )
         # split text into list of normal text and special tokens
-        text_parts = re.split(pattern, input_text)
+        text_parts = re.split(pattern, prompt)

         # TODO(mick): load from server_args, env, or sampling_params
         MAX_NUM_FRAMES = 30
@@ -217,7 +191,7 @@ class BaseMultimodalProcessor(ABC):
             ):
                 # video
                 path = image_file[len("video:") :]
-                frames = BaseMultimodalProcessor.encode_video(
+                frames = encode_video(
                     path, frame_count_limit=frames_to_process
                 )
             else:
@@ -254,19 +228,9 @@ class BaseMultimodalProcessor(ABC):
             raise RuntimeError(f"An exception occurred while loading images: {e}")

         out = BaseMultiModalProcessorOutput(
-            mm_data_hashes=hashes,
-            image_sizes=image_sizes,
             images=images,
             audios=audios,
             input_text=new_text,
         )
         out.normalize()
         return out
-
-
-def init_global_processor(sglang_processor: BaseMultimodalProcessor, server_args):
-    """Init the global processor for multimodal models."""
-    global global_processor
-    transformers.logging.set_verbosity_error()
-    global_processor = sglang_processor._build_processor(server_args=server_args)
```
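With the `get_global_processor()` indirection and the fork-initialized process pool gone, a subclass only needs to build its prompt, load the raw media, and hand both to the shared `process_mm_data` helper. A minimal sketch of a subclass under these assumptions; the class name and the `<image>` token literal are illustrative, not part of the diff:

```python
# Illustrative subclass; class name and image token are hypothetical.
from sglang.srt.managers.multimodal_processors.base_processor import (
    BaseMultimodalProcessor,
    MultimodalSpecialTokens,
)


class MyProcessor(BaseMultimodalProcessor):
    async def process_mm_data_async(
        self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
    ):
        base_output = self.load_mm_data(
            prompt=input_ids,
            image_data=image_data,
            multimodal_tokens=MultimodalSpecialTokens(image_token="<image>"),
            max_req_input_len=max_req_input_len,
        )
        # One synchronous call into the wrapped HF AutoProcessor replaces the
        # old round trip through a fork-based ProcessPoolExecutor.
        return self.process_mm_data(
            input_text=base_output.input_text, images=base_output.images
        )
```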
python/sglang/srt/managers/multimodal_processors/clip.py

```diff
-import asyncio
 from typing import List, Union

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.clip import CLIPModel
 from sglang.srt.utils import load_image
@@ -15,29 +14,6 @@ class ClipImageProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)

-    @staticmethod
-    def _process_single_image_task(images, input_text):
-        # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
-        return get_global_processor()(
-            images=images, text=input_text, return_tensors="pt"
-        )
-
-    async def _process_single_image(self, images, input_text):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                ClipImageProcessor._process_single_image_task,
-                images,
-                input_text,
-            )
-        else:
-            image_inputs = self._processor(
-                images=images, text=[input_text], return_tensors="pt"
-            )
-        return image_inputs
-
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
@@ -56,8 +32,13 @@ class ClipImageProcessor(BaseMultimodalProcessor):
         else:
             images = load_image(image_data[0])[0]

-        image_inputs = await self._process_single_image(images, input_text)
+        image_inputs = self.process_mm_data(input_text=input_text, images=images)
         image_inputs["data_hashes"] = [hash(str(image_data))]
         image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+        image_inputs["mm_items"] = [
+            MultimodalDataItem(
+                pixel_values=image_inputs["pixel_values"], modality=Modality.IMAGE
+            )
+        ]

         return image_inputs
```
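For context, the dict a processor now returns is consumed by `MultimodalInputs.from_dict` (defined in schedule_batch.py below). A rough sketch of that hand-off, with a placeholder tensor standing in for a real CLIP preprocessing result:

```python
import torch

from sglang.srt.managers.schedule_batch import (
    Modality,
    MultimodalDataItem,
    MultimodalInputs,
)

# Placeholder pixel data; a real processor returns the HF AutoProcessor output.
ret = {
    "input_ids": [101, 32000, 102],
    "mm_items": [
        MultimodalDataItem(
            pixel_values=torch.randn(1, 3, 224, 224), modality=Modality.IMAGE
        )
    ],
}

mm_inputs = MultimodalInputs.from_dict(ret)  # hashes each item, sets pad values
assert mm_inputs.contains_image_inputs()
```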
python/sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py

```diff
@@ -16,15 +16,14 @@
 # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
-import asyncio
-
 import torch

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
@@ -35,51 +34,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         super().__init__(hf_config, server_args, _processor)
         self.IMAGE_TOKEN = "<image>"

-    @staticmethod
-    def _process_images_task(image, input_text, max_req_input_len):
-        processor = get_global_processor()
-        res = processor.__call__(
-            conversations=input_text, images=image, max_req_input_len=max_req_input_len
-        )
-
-        image_token_id = processor.image_token_id
-        res["im_token_id"] = image_token_id
-        return res
-
-    async def _process_images(self, image_data, input_text, max_req_input_len):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                DeepseekVL2ImageProcessor._process_images_task,
-                image_data,
-                input_text,
-                max_req_input_len,
-            )
-        else:
-            image_inputs = self._process_images_task(
-                image_data, input_text, max_req_input_len
-            )
-        return image_inputs
-
     async def process_mm_data_async(
         self, image_data, input_ids, request_obj, max_req_input_len, *args, **kwargs
     ):
@@ -89,8 +43,6 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
         if not isinstance(image_data, list):
             image_data = [image_data]

-        images, image_sizes = [], []
-
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
             input_ids,
@@ -98,8 +50,11 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
         )
-        res = await self._process_images(
-            base_output.images, base_output.input_text, max_req_input_len
+        res = self.process_mm_data(
+            input_text=base_output.input_text,
+            images=base_output.images,
+            max_req_input_len=max_req_input_len,
+            conversations=base_output.input_text,
         )
         images_seq_mask = res["images_seq_mask"]
         images_spatial_crop = res["images_spatial_crop"]
@@ -107,13 +62,17 @@ class DeepseekVL2ImageProcessor(BaseMultimodalProcessor):
             batched_images_spatial_crop.append(images_spatial_crop)
         batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)

+        items = []
+        item = MultimodalDataItem(
+            pixel_values=res["images"],
+            modality=Modality.IMAGE,
+            image_emb_mask=images_seq_mask,
+            image_spatial_crop=batched_images_spatial_crop,
+        )
+        items += [item]
+
         return {
+            "mm_items": items,
             "input_ids": res["input_ids"].tolist(),
-            "pixel_values": res["images"],
-            "im_token_id": res["im_token_id"],
-            "data_hashes": base_output.mm_data_hashes,
-            "image_sizes": image_sizes,
-            "images_emb_mask": images_seq_mask,
-            "image_spatial_crop": batched_images_spatial_crop,
-            "modalities": request_obj.modalities or ["image"],
+            "im_token_id": self._processor.image_token_id,
         }
```
python/sglang/srt/managers/multimodal_processors/gemma3.py

```diff
@@ -7,8 +7,8 @@ from sglang.srt.managers.multimodal_processor import (
 )
 from sglang.srt.managers.multimodal_processors.base_processor import (
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.gemma3_mm import Gemma3ForConditionalGeneration

 # Copied from: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma3/image_processing_gemma3_fast.py
@@ -25,28 +25,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         self.IM_START_TOKEN_ID = hf_config.boi_token_index
         self.IM_END_TOKEN_ID = hf_config.eoi_token_index

-    async def _process_single_image(self, images, input_text) -> dict:
-        if isinstance(images, list) and len(images) == 0:
-            images = None
-        processor = get_global_processor()
-        result = processor.__call__(
-            text=[input_text],
-            images=images,
-            padding=True,
-            return_tensors="pt",
-            # if RGBA, this needs to be set
-            # images_kwargs={
-            #     "input_data_format": ChannelDimension.FIRST
-            # }
-        )
-
-        pixel_values = getattr(result, "pixel_values", None)
-
-        return {
-            "input_ids": result.input_ids,
-            "pixel_values": pixel_values,
-        }
-
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
@@ -63,21 +41,28 @@ class Gemma3SGLangImageProcessor(SGLangBaseProcessor):
         image_token = self.IMAGE_TOKEN
         base_output = self.load_mm_data(
-            input_ids=input_ids,
+            prompt=input_ids,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(image_token=image_token),
             max_req_input_len=max_req_input_len,
             discard_alpha_channel=True,
         )

-        ret = await self._process_single_image(
+        ret = self.process_mm_data(
             input_text=base_output.input_text, images=base_output.images
         )
+
+        items = []
+        for i, image in enumerate(base_output.images):
+            item = MultimodalDataItem(
+                pixel_values=ret["pixel_values"][i],
+                modality=Modality.IMAGE,
+            )
+            items += [item]
+
         return {
+            "mm_items": items,
             "input_ids": ret["input_ids"].flatten().tolist(),
-            "pixel_values": ret["pixel_values"],
-            "data_hashes": base_output.mm_data_hashes,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
         }
```
python/sglang/srt/managers/multimodal_processors/janus_pro.py

```diff
-import asyncio
 from typing import List, Union

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
@@ -15,37 +14,6 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)

-    @staticmethod
-    def _process_images_task(images, input_text):
-        processor = get_global_processor()
-        result = processor.__call__(
-            prompt=input_text, images=images, return_tensors="pt"
-        )
-        return {
-            "input_ids": result["input_ids"],
-            "pixel_values": result["pixel_values"],
-            "images_emb_mask": result["images_emb_mask"],
-            "im_start_id": processor.image_start_id,
-            "im_end_id": processor.image_end_id,
-            "im_token_id": processor.image_id,
-        }
-
-    async def _process_images(self, images, input_text):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                JanusProImageProcessor._process_images_task,
-                images,
-                input_text,
-            )
-        else:
-            image_inputs = self._processor(
-                images=images, text=input_text, return_tensors="pt"
-            )
-        return image_inputs
-
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
@@ -60,25 +28,31 @@ class JanusProImageProcessor(BaseMultimodalProcessor):
         if not isinstance(image_data, list):
             image_data = [image_data]

+        processor = self._processor
+
         base_out = self.load_mm_data(
-            input_ids=input_ids,
+            prompt=input_ids,
             image_data=image_data,
             multimodal_tokens=MultimodalSpecialTokens(
-                image_token="<image_placeholder>"
+                image_token=processor.image_tag
             ),
             max_req_input_len=max_req_input_len,
         )
         images = base_out.images
-        res = await self._process_images(images=images, input_text=base_out.input_text)
-        # print(res)
-        # print(base_out)
-        # print("", res["images_emb_mask"].shape)
+        res = self.process_mm_data(
+            input_text=base_out.input_text,
+            prompt=base_out.input_text,
+            images=images,
+        )
         return {
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=res["pixel_values"],
+                    image_emb_mask=res["images_emb_mask"],
+                    modality=Modality.IMAGE,
+                )
+            ],
             "input_ids": res["input_ids"].flatten().tolist(),
-            "pixel_values": res["pixel_values"],
-            "images_emb_mask": res["images_emb_mask"],
-            "data_hashes": base_out.mm_data_hashes,
-            "im_start_id": res["im_start_id"],
-            "im_end_id": res["im_end_id"],
-            "im_token_id": res["im_token_id"],
+            "im_start_id": processor.image_start_id,
+            "im_end_id": processor.image_end_id,
+            "im_token_id": processor.image_id,
         }
```
python/sglang/srt/managers/multimodal_processors/llava.py

```diff
@@ -5,8 +5,8 @@ import numpy as np
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.models.llava import LlavaMistralForCausalLM, LlavaQwenForCausalLM
 from sglang.srt.models.llavavid import LlavaVidForCausalLM
@@ -25,11 +25,10 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
         image_data: Union[str, bytes],
         image_aspect_ratio: Optional[str] = None,
         image_grid_pinpoints: Optional[str] = None,
-        image_processor=None,
+        processor=None,
     ):
-        processor = get_global_processor()
-        image_processor = image_processor or processor.image_processor
+        image_processor = processor.image_processor

         try:
             image, image_size = load_image(image_data)
@@ -72,18 +71,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
     async def _process_single_image(
         self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
     ):
-        if self.executor is not None:
+        if self.cpu_executor is not None:
             loop = asyncio.get_event_loop()
             return await loop.run_in_executor(
-                self.executor,
+                self.cpu_executor,
                 LlavaImageProcessor._process_single_image_task,
                 image_data,
                 aspect_ratio,
                 grid_pinpoints,
+                self._processor,
             )
         else:
             return self._process_single_image_task(
-                image_data, aspect_ratio, grid_pinpoints
+                image_data,
+                aspect_ratio,
+                grid_pinpoints,
+                self._processor.image_processor,
             )

     async def process_mm_data_async(
@@ -134,14 +137,22 @@ class LlavaImageProcessor(BaseMultimodalProcessor):
             pixel_values, image_hash, image_size = await self._process_single_image(
                 image_data[0], aspect_ratio, grid_pinpoints
             )
-            data_hashes = [image_hash]
             image_sizes = [image_size]
         else:
             raise ValueError(f"Invalid image data: {image_data}")

+        modality = Modality.IMAGE
+        if isinstance(request_obj.modalities, list):
+            if request_obj.modalities[0] == "multi-images":
+                modality = Modality.MULTI_IMAGES
+            elif request_obj.modalities[0] == "video":
+                modality = Modality.VIDEO
+
         return {
-            "pixel_values": pixel_values,
-            "data_hashes": data_hashes,
-            "image_sizes": image_sizes,
-            "modalities": request_obj.modalities or ["image"],
+            "mm_items": [
+                MultimodalDataItem(
+                    pixel_values=pixel_values,
+                    image_sizes=image_sizes,
+                    modality=modality,
+                )
+            ],
         }
```
python/sglang/srt/managers/multimodal_processors/minicpm.py

```diff
-import asyncio
 from typing import List, Union

 import torch
+from transformers import BaseImageProcessorFast

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
     MultimodalSpecialTokens,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.minicpmo import MiniCPMO
 from sglang.srt.models.minicpmv import MiniCPMV
@@ -21,19 +21,23 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         self.image_token = "(<image>./</image>)"
         self.audio_token = "(<audio>./</audio>)"

-    @staticmethod
-    def _process_data_task(input_text, images=None, audios=None):
+    def process_data_task(self, input_text, images=None, audios=None):
         if isinstance(images, list) and len(images) == 0:
             images = None
         if isinstance(audios, list) and len(audios) == 0:
             audios = None
-        result = get_global_processor().__call__(
+        processor = self._processor
+        args = {}
+        if isinstance(processor, BaseImageProcessorFast):
+            args["device"] = "cuda"
+        result = self._processor.__call__(
             text=input_text,
             images=images,
             audios=audios,
             return_tensors="pt",
             chunk_input=True,
+            **args,
         )
         return {
             "input_ids": result.input_ids,
@@ -44,23 +48,6 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             "audio_bounds": getattr(result, "audio_bounds", None),
         }

-    async def _process_data(self, images, input_text, audios=None):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            multimodal_data_inputs = await loop.run_in_executor(
-                self.executor,
-                MiniCPMMultimodalProcessor._process_data_task,
-                input_text,
-                images,
-                audios,
-            )
-        else:
-            multimodal_data_inputs = self._processor(
-                images=images, text=input_text, audios=audios, return_tensors="pt"
-            )
-        return multimodal_data_inputs
-
     async def process_mm_data_async(
         self,
         image_data: List[Union[str, bytes]],
@@ -77,7 +64,7 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             audio_data = [audio_data]

         base_output = self.load_mm_data(
-            input_ids=input_ids,
+            prompt=input_ids,
             max_req_input_len=max_req_input_len,
             audio_data=audio_data,
             image_data=image_data,
@@ -88,9 +75,9 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
         if base_output is None:
             return None

-        res = await self._process_data(
-            images=base_output.images,
+        res = self.process_mm_data(
             input_text=base_output.input_text,
+            images=base_output.images,
             audios=base_output.audios,
         )
@@ -142,23 +129,33 @@ class MiniCPMMultimodalProcessor(BaseMultimodalProcessor):
             tgt_sizes_flat += [tgt_n]

         pixel_values = pixel_values_flat
-        if len(tgt_sizes_flat) == 0:
-            tgt_sizes = None
-        else:
-            tgt_sizes = torch.stack(tgt_sizes_flat)
-        if not isinstance(res["audio_features"], list):
-            res["audio_features"] = [res["audio_features"]]
+
+        items = []
+        if len(pixel_values) != 0:
+            item = MultimodalDataItem(
+                pixel_values=pixel_values,
+                tgt_size=tgt_sizes_flat,
+                modality=Modality.IMAGE,
+            )
+            items += [item]
+
+        if (
+            "audio_features" in res
+            and res["audio_features"] is not None
+            and len(res["audio_features"]) != 0
+        ):
+            item = MultimodalDataItem(
+                audio_features=[res["audio_features"]],
+                audio_feature_lens=res["audio_feature_lens"],
+                modality=Modality.AUDIO,
+            )
+            items += [item]

         return {
+            "mm_items": items,
             "input_ids": res["input_ids"].flatten().tolist(),
-            "pixel_values": pixel_values,
-            "tgt_sizes": tgt_sizes,
-            "data_hashes": base_output.mm_data_hashes,
-            "modalities": request_obj.modalities or ["image"],
             "audio_start_id": audio_start_id,
             "audio_end_id": audio_end_id,
-            "audio_features": res["audio_features"],
-            "audio_bounds": res["audio_bounds"],
-            "audio_feature_lens": res["audio_feature_lens"],
             "im_token_id": im_token_id,
             "im_start_id": tokenizer.im_start_id,
             "im_end_id": tokenizer.im_end_id,
```
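The new fast-processor branch above offloads preprocessing to the GPU when the wrapped HF processor is a `BaseImageProcessorFast` (the torch-backed fast image processors accept a `device` argument; the slow PIL/numpy ones do not). A minimal sketch of that capability check, assuming a CUDA device is available:

```python
from transformers import BaseImageProcessorFast


def build_call_kwargs(processor):
    # Fast (torch-backed) image processors can run resize/normalize on the
    # GPU; pass device only when the processor actually supports it.
    kwargs = {}
    if isinstance(processor, BaseImageProcessorFast):
        kwargs["device"] = "cuda"  # assumption: a CUDA device is available
    return kwargs
```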
python/sglang/srt/managers/multimodal_processors/mlama.py

```diff
-import asyncio
 from typing import List, Union

 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
-    get_global_processor,
 )
+from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.mllama import MllamaForConditionalGeneration
 from sglang.srt.utils import load_image
@@ -15,25 +14,6 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)

-    @staticmethod
-    def _process_single_image_task(images, input_text):
-        # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
-        return get_global_processor()(images, input_text, return_tensors="pt")
-
-    async def _process_single_image(self, images, input_text):
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            image_inputs = await loop.run_in_executor(
-                self.executor,
-                MllamaImageProcessor._process_single_image_task,
-                images,
-                input_text,
-            )
-        else:
-            image_inputs = self._processor(images, input_text, return_tensors="pt")
-
-        return image_inputs
-
     async def process_mm_data_async(
         self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
     ):
@@ -52,8 +32,15 @@ class MllamaImageProcessor(BaseMultimodalProcessor):
         else:
             images = load_image(image_data[0])[0]

-        image_inputs = await self._process_single_image(images, input_text)
-        image_inputs["data_hashes"] = [hash(str(image_data))]
+        image_inputs = self.process_mm_data(input_text=input_text, images=images)
         image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+        image_inputs["mm_items"] = [
+            MultimodalDataItem(
+                pixel_values=image_inputs["pixel_values"],
+                aspect_ratio_id=image_inputs["aspect_ratio_ids"],
+                aspect_ratio_mask=image_inputs["aspect_ratio_mask"],
+                modality=Modality.IMAGE,
+            )
+        ]

         return image_inputs
```
python/sglang/srt/managers/multimodal_processors/qwen_vl.py
View file @
5cb552b1
import
asyncio
import
asyncio
import
math
import
math
import
time
from
typing
import
List
,
Union
from
typing
import
List
,
Union
import
torch
import
torch
...
@@ -11,8 +10,8 @@ from sglang.srt.managers.multimodal_processor import (
...
@@ -11,8 +10,8 @@ from sglang.srt.managers.multimodal_processor import (
)
)
from
sglang.srt.managers.multimodal_processors.base_processor
import
(
from
sglang.srt.managers.multimodal_processors.base_processor
import
(
MultimodalSpecialTokens
,
MultimodalSpecialTokens
,
get_global_processor
,
)
)
from
sglang.srt.managers.schedule_batch
import
Modality
,
MultimodalDataItem
from
sglang.srt.models.qwen2_5_vl
import
Qwen2_5_VLForConditionalGeneration
from
sglang.srt.models.qwen2_5_vl
import
Qwen2_5_VLForConditionalGeneration
from
sglang.srt.models.qwen2_vl
import
Qwen2VLForConditionalGeneration
from
sglang.srt.models.qwen2_vl
import
Qwen2VLForConditionalGeneration
...
@@ -34,45 +33,15 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
...
@@ -34,45 +33,15 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
self
.
MAX_PIXELS
=
16384
*
28
*
28
self
.
MAX_PIXELS
=
16384
*
28
*
28
self
.
MAX_RATIO
=
200
self
.
MAX_RATIO
=
200
@
staticmethod
def
_process_images_task
(
images
,
input_text
,
_hf_config
):
if
isinstance
(
images
,
list
)
and
len
(
images
)
==
0
:
images
=
None
result
=
get_global_processor
().
__call__
(
text
=
[
input_text
],
images
=
images
,
padding
=
True
,
return_tensors
=
"pt"
)
return
{
"input_ids"
:
result
.
input_ids
,
"pixel_values"
:
getattr
(
result
,
"pixel_values"
,
None
),
"image_grid_thw"
:
getattr
(
result
,
"image_grid_thw"
,
None
),
"second_per_grid_ts"
:
getattr
(
result
,
"second_per_grid_ts"
,
None
),
"video_grid_thws"
:
getattr
(
result
,
"video_grid_thws"
,
None
),
}
async
def
_process_single_image
(
self
,
images
,
input_text
)
->
dict
:
if
self
.
executor
is
not
None
:
loop
=
asyncio
.
get_event_loop
()
return
await
loop
.
run_in_executor
(
self
.
executor
,
Qwen2_5VLImageProcessor
.
_process_images_task
,
images
,
input_text
,
self
.
hf_config
,
)
else
:
return
self
.
_process_images_task
(
images
,
input_text
,
self
.
hf_config
)
async
def
process_mm_data_async
(
async
def
process_mm_data_async
(
self
,
self
,
image_data
:
List
[
Union
[
str
,
bytes
]],
image_data
:
List
[
Union
[
str
,
bytes
]],
input_ids
,
prompt
,
request_obj
,
request_obj
,
max_req_input_len
,
max_req_input_len
,
*
args
,
*
args
,
**
kwargs
,
**
kwargs
,
):
):
start
=
time
.
time
()
if
not
image_data
:
if
not
image_data
:
return
None
return
None
if
isinstance
(
image_data
,
str
):
if
isinstance
(
image_data
,
str
):
...
@@ -80,7 +49,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
...
@@ -80,7 +49,7 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
image_token
=
self
.
IMAGE_TOKEN
image_token
=
self
.
IMAGE_TOKEN
base_output
=
self
.
load_mm_data
(
base_output
=
self
.
load_mm_data
(
input_ids
=
input_ids
,
prompt
=
prompt
,
image_data
=
image_data
,
image_data
=
image_data
,
multimodal_tokens
=
MultimodalSpecialTokens
(
image_token
=
image_token
),
multimodal_tokens
=
MultimodalSpecialTokens
(
image_token
=
image_token
),
max_req_input_len
=
max_req_input_len
,
max_req_input_len
=
max_req_input_len
,
...
@@ -144,24 +113,32 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
...
@@ -144,24 +113,32 @@ class Qwen2_5VLImageProcessor(SGLangBaseProcessor):
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
"""Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
return
math
.
floor
(
number
/
factor
)
*
factor
return
math
.
floor
(
number
/
factor
)
*
factor
images
=
[
resize_image
(
image
)
for
image
in
base_output
.
images
]
async
def
resize_image_async
(
image
):
return
resize_image
(
image
)
ret
=
await
self
.
_process_single_image
(
resize_tasks
=
[
resize_image_async
(
image
)
for
image
in
base_output
.
images
]
images
=
images
,
input_text
=
base_output
.
input_text
resized_images
=
await
asyncio
.
gather
(
*
resize_tasks
)
ret
=
self
.
process_mm_data
(
input_text
=
base_output
.
input_text
,
images
=
resized_images
,
)
)
image_grid_thws
=
torch
.
concat
([
ret
[
"image_grid_thw"
]])
image_grid_thws
=
torch
.
concat
([
ret
[
"image_grid_thw"
]])
video_grid_thws
=
None
return
{
return
{
"input_ids"
:
ret
[
"input_ids"
].
flatten
().
tolist
(),
"input_ids"
:
ret
[
"input_ids"
].
flatten
().
tolist
(),
"pixel_values"
:
ret
[
"pixel_values"
],
"mm_items"
:
[
"data_hashes"
:
base_output
.
mm_data_hashes
,
MultimodalDataItem
(
"modalities"
:
request_obj
.
modalities
or
[
"image"
],
pixel_values
=
ret
[
"pixel_values"
],
"image_grid_thws"
:
image_grid_thws
,
image_grid_thws
=
image_grid_thws
,
"video_grid_thws"
:
video_grid_thws
,
# TODO
video_grid_thws
=
None
,
second_per_grid_ts
=
ret
.
get
(
"second_per_grid_ts"
,
None
),
modality
=
Modality
.
IMAGE
,
)
],
"im_start_id"
:
self
.
IM_START_TOKEN_ID
,
"im_start_id"
:
self
.
IM_START_TOKEN_ID
,
"im_end_id"
:
self
.
IM_END_TOKEN_ID
,
"im_end_id"
:
self
.
IM_END_TOKEN_ID
,
"im_token_id"
:
self
.
image_token_id
,
"im_token_id"
:
self
.
image_token_id
,
"video_token_id"
:
self
.
video_token_id
,
"video_token_id"
:
self
.
video_token_id
,
"second_per_grid_ts"
:
ret
[
"second_per_grid_ts"
],
}
}
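The resize step now fans out through `asyncio.gather`. Worth noting: wrapping a synchronous function in `async def` (as the diff does) only yields real concurrency if the body eventually awaits, for example by deferring to an executor. A small self-contained sketch of that stronger form, with a trivial stand-in for the real smart-resize logic:

```python
import asyncio


def resize_image(image):
    # Stand-in for the real smart-resize logic in qwen_vl.py.
    return image


async def resize_all(images):
    loop = asyncio.get_running_loop()
    # run_in_executor gives true thread-level concurrency for blocking work.
    tasks = [loop.run_in_executor(None, resize_image, img) for img in images]
    return await asyncio.gather(*tasks)


print(asyncio.run(resize_all([1, 2, 3])))  # [1, 2, 3]
```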
python/sglang/srt/managers/schedule_batch.py
View file @
5cb552b1
from
__future__
import
annotations
from
__future__
import
annotations
from
enum
import
Enum
,
auto
# Copyright 2023-2024 SGLang Team
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
...
@@ -51,7 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
...
@@ -51,7 +53,7 @@ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, Forw
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.sampling.sampling_batch_info
import
SamplingBatchInfo
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.server_args
import
ServerArgs
from
sglang.srt.utils
import
get_compiler_backend
from
sglang.srt.utils
import
flatten_nested_list
,
get_compiler_backend
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.speculative.eagle_utils
import
EagleDraftInput
,
EagleVerifyInput
from
sglang.srt.speculative.eagle_utils
import
EagleDraftInput
,
EagleVerifyInput
...
@@ -143,165 +145,185 @@ class FINISH_ABORT(BaseFinishReason):
...
@@ -143,165 +145,185 @@ class FINISH_ABORT(BaseFinishReason):
}
}
class
Modality
(
Enum
):
IMAGE
=
auto
()
MULTI_IMAGES
=
auto
()
VIDEO
=
auto
()
AUDIO
=
auto
()
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
MultimodalInputs
:
class
MultimodalDataItem
:
"""The image related inputs."""
"""
A single multimodal data, from a single image/video/audio or other
"""
modality
:
Modality
hash
:
int
=
None
pad_value
:
int
=
None
pixel_values
:
Union
[
torch
.
Tensor
,
np
.
array
]
aspect_ratio_id
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
data_hashes
:
Optional
[
list
]
=
None
aspect_ratio_mask
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
image_sizes
:
Optional
[
list
]
=
None
image_sizes
:
Tuple
[
int
,
int
]
=
None
image_offsets
:
Optional
[
list
]
=
None
image_offsets
:
Optional
[
list
]
=
None
# the real data, pixel_values or audio_features
# data: Union[List[torch.Tensor], List[np.array]]
pixel_values
:
Union
[
torch
.
Tensor
,
np
.
array
]
=
None
image_grid_thws
:
Union
[
torch
.
Tensor
,
np
.
array
]
=
None
video_grid_thws
:
Union
[
torch
.
Tensor
,
np
.
array
]
=
None
image_emb_mask
:
Optional
[
torch
.
Tensor
]
=
None
image_spatial_crop
:
Optional
[
torch
.
Tensor
]
=
None
second_per_grid_ts
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
# [num_images, (n, w, h)]
tgt_size
:
Tuple
[
int
,
int
]
=
None
audio_features
:
Union
[
torch
.
Tensor
,
np
.
array
]
=
None
audio_feature_lens
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
@
staticmethod
def
is_empty_list
(
l
):
if
l
is
None
:
return
True
return
len
([
item
for
item
in
flatten_nested_list
(
l
)
if
item
is
not
None
])
==
0
def
set_pad_value
(
self
):
"""
Set the pad value after first hashign the data
"""
def
hash_feature
(
f
):
if
isinstance
(
f
,
list
):
return
hash
(
tuple
(
flatten_nested_list
(
f
)))
elif
isinstance
(
f
,
np
.
ndarray
):
arr
=
np
.
ascontiguousarray
(
f
)
arr_bytes
=
arr
.
tobytes
()
return
hash
(
arr_bytes
)
return
hash
(
f
)
if
self
.
is_audio
():
self
.
hash
=
hash_feature
(
self
.
audio_features
)
else
:
self
.
hash
=
hash_feature
(
self
.
pixel_values
)
assert
self
.
hash
is
not
None
self
.
pad_value
=
self
.
hash
%
(
1
<<
30
)
def
is_audio
(
self
):
return
(
self
.
modality
==
Modality
.
AUDIO
)
and
not
MultimodalDataItem
.
is_empty_list
(
self
.
audio_features
)
def
is_image
(
self
):
return
(
self
.
modality
==
Modality
.
IMAGE
or
self
.
modality
==
Modality
.
MULTI_IMAGES
)
and
not
MultimodalDataItem
.
is_empty_list
(
self
.
pixel_values
)
def
is_video
(
self
):
return
(
self
.
modality
==
Modality
.
VIDEO
)
and
not
MultimodalDataItem
.
is_empty_list
(
self
.
pixel_values
)
def
validate
(
self
):
...
# TODO
@
dataclasses
.
dataclass
class
MultimodalInputs
:
"""The multimodal data related inputs."""
# items of data
mm_items
:
List
[
MultimodalDataItem
]
image_pad_len
:
Optional
[
list
]
=
None
image_pad_len
:
Optional
[
list
]
=
None
pad_values
:
Optional
[
list
]
=
None
modalities
:
Optional
[
list
]
=
None
num_image_tokens
:
Optional
[
int
]
=
None
num_image_tokens
:
Optional
[
int
]
=
None
# Llava related
aspect_ratio_ids
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
aspect_ratio_mask
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
# QWen2-VL related
# QWen2-VL related
# [num_of_images, t, h, w]
image_grid_thws
:
torch
.
Tensor
=
None
mrope_position_delta
:
Optional
[
torch
.
Tensor
]
=
None
mrope_position_delta
:
Optional
[
torch
.
Tensor
]
=
None
# Qwen2-VL video related
video_token_id
:
Optional
[
int
]
=
None
video_grid_thws
:
List
[
Tuple
[
int
,
int
,
int
]]
=
None
second_per_grid_ts
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
# deepseek vl2 related
# image
images_emb_mask
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
image_spatial_crop
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
# The id of the single-image placeholder token
im_token_id
:
Optional
[
torch
.
Tensor
]
=
None
im_token_id
:
Optional
[
torch
.
Tensor
]
=
None
# All the images in the batch should share the same special image
# bound token ids.
im_start_id
:
Optional
[
int
]
=
None
im_start_id
:
Optional
[
int
]
=
None
im_end_id
:
Optional
[
int
]
=
None
im_end_id
:
Optional
[
int
]
=
None
slice_start_id
:
Optional
[
int
]
=
None
slice_start_id
:
Optional
[
int
]
=
None
slice_end_id
:
Optional
[
int
]
=
None
slice_end_id
:
Optional
[
int
]
=
None
# [num_images, 2 (w, h)]
tgt_sizes
:
Optional
[
list
]
=
None
# video
video_token_id
:
Optional
[
int
]
=
None
# audio
# audio
audio_start_id
:
Optional
[
torch
.
Tensor
]
=
None
audio_start_id
:
Optional
[
torch
.
Tensor
]
=
None
audio_end_id
:
Optional
[
torch
.
Tensor
]
=
None
audio_end_id
:
Optional
[
torch
.
Tensor
]
=
None
audio_features
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
audio_feature_lens
:
Optional
[
List
[
torch
.
Tensor
]]
=
None
@
staticmethod
@
staticmethod
def
from_dict
(
obj
:
dict
):
def
from_dict
(
obj
:
dict
):
ret
=
MultimodalInputs
(
ret
=
MultimodalInputs
(
pixel_values
=
obj
[
"pixel_values"
],
mm_items
=
obj
[
"mm_items"
],
data_hashes
=
obj
[
"data_hashes"
],
)
)
assert
isinstance
(
ret
.
mm_items
,
list
)
ret
.
mm_items
=
[
item
for
item
in
ret
.
mm_items
if
item
.
is_audio
()
or
item
.
is_image
()
or
item
.
is_video
()
]
assert
len
(
ret
.
mm_items
)
!=
0
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
# Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
# Please note that if the `input_ids` is later used in the model forward,
# Please note that if the `input_ids` is later used in the model forward,
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
# you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
# errors in cuda kernels. See also llava.py for example.
# errors in cuda kernels. See also llava.py for example.
ret
.
pad_values
=
[
x
%
(
1
<<
30
)
for
x
in
ret
.
data_hashes
]
for
item
in
ret
.
mm_items
:
item
.
set_pad_value
()
optional_args
=
[
optional_args
=
[
"image_sizes"
,
"modalities"
,
"modalities"
,
"aspect_ratio_ids"
,
"aspect_ratio_mask"
,
"image_grid_thws"
,
"images_emb_mask"
,
"image_spatial_crop"
,
"im_token_id"
,
"im_token_id"
,
"im_start_id"
,
"im_start_id"
,
"im_end_id"
,
"im_end_id"
,
"slice_start_id"
,
"slice_start_id"
,
"slice_end_id"
,
"slice_end_id"
,
"tgt_sizes"
,
"audio_start_id"
,
"audio_start_id"
,
"audio_end_id"
,
"audio_end_id"
,
"audio_features"
,
"audio_feature_lens"
,
]
]
for
arg
in
optional_args
:
for
arg
in
optional_args
:
if
arg
in
obj
:
if
arg
in
obj
:
setattr
(
ret
,
arg
,
obj
[
arg
])
setattr
(
ret
,
arg
,
obj
[
arg
])
# validate
assert
(
isinstance
(
ret
.
pixel_values
,
torch
.
Tensor
)
or
isinstance
(
ret
.
pixel_values
,
np
.
ndarray
)
or
isinstance
(
ret
.
pixel_values
,
list
)
)
assert
ret
.
audio_features
is
None
or
isinstance
(
ret
.
audio_features
,
list
)
return
ret
return
ret
def
contains_image_inputs
(
self
)
->
bool
:
def
contains_image_inputs
(
self
)
->
bool
:
""" """
""" """
return
self
.
pixel_values
is
not
None
and
self
.
pixel_values
!=
[]
return
any
(
item
.
is_image
()
for
item
in
self
.
mm_items
)
def
contains_audio_inputs
(
self
)
->
bool
:
def
contains_audio_inputs
(
self
)
->
bool
:
""" """
""" """
return
self
.
audio_features
is
not
None
and
self
.
audio_features
!=
[]
return
any
(
item
.
is_audio
()
for
item
in
self
.
mm_items
)
def
collect_image_inputs
(
self
)
->
List
[
torch
.
Tensor
]:
return
[
item
.
pixel_values
for
item
in
self
.
mm_items
if
item
.
is_image
()]
     def merge(self, other: MultimodalInputs):
         """
         merge image inputs when requests are being merged
         """
-        if isinstance(self.pixel_values, list):
-            # in some rare cases, pixel values are list of patches with different shapes
-            # e.g. minicpm
-            self.pixel_values += other.pixel_values
-        else:
-            assert (
-                self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
-            ), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
-            self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
-
-        # args would be stacked along first dim
-        # usually these are already tensors
-        stack_args = [
-            # TODO: merge with image_grid_thws, basically the same thing
-            "tgt_sizes",
-            "image_spatial_crop",
-        ]
-        for arg in stack_args:
-            if getattr(self, arg, None) is None:
-                setattr(self, arg, getattr(other, arg, None))
-            elif getattr(other, arg, None) is not None:
-                # self and other both not None
-                setattr(
-                    self,
-                    arg,
-                    torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
-                )
-
-        if self.image_grid_thws is None:
-            self.image_grid_thws = other.image_grid_thws
-        elif other.image_grid_thws is not None:
-            self.image_grid_thws = torch.concat(
-                [self.image_grid_thws, other.image_grid_thws]
-            )
-
         # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
         # Please note that if the `input_ids` is later used in the model forward,
         # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
         # errors in cuda kernels. See also llava.py for example.
-        self.data_hashes += other.data_hashes
-        self.pad_values = [x % (1 << 30) for x in self.data_hashes]

         # args needed to be merged
         optional_args = [
-            "audio_features",
-            "image_sizes",
+            "items",
             "image_offsets",
             "image_pad_len",
             # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
-            "aspect_ratio_ids",
-            "aspect_ratio_mask",
-            "images_emb_mask",
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)
...
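The payoff of the item list is visible in `merge`: list concatenation composes regardless of per-item tensor shapes, whereas the removed `np.concatenate` path required every `pixel_values` tensor to share trailing dimensions. A minimal illustration:

    import torch

    # Items from two requests; resolutions differ, which np.concatenate could not handle.
    a_items = [torch.rand(3, 224, 224)]
    b_items = [torch.rand(3, 448, 448)]

    merged = a_items + b_items  # always well-defined; shape handling is deferred per item
    assert len(merged) == 2 and merged[0].shape != merged[1].shape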
python/sglang/srt/managers/scheduler.py  View file @ 5cb552b1
...
@@ -112,7 +112,7 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
...
python/sglang/srt/managers/utils.py  View file @ 5cb552b1
-import json
 import logging
-import time
-from collections import defaultdict
 from http import HTTPStatus
-from typing import Dict, List, Optional, Tuple
+from typing import Optional
-import torch
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
...
python/sglang/srt/model_executor/forward_batch_info.py  View file @ 5cb552b1
...
@@ -355,11 +355,6 @@ class ForwardBatch:
             for mm_input in valid_inputs[1:]:
                 merged.merge(mm_input)

-            if isinstance(merged.pixel_values, np.ndarray):
-                merged.pixel_values = torch.from_numpy(merged.pixel_values)
-            if isinstance(merged.audio_features, np.ndarray):
-                merged.audio_features = torch.from_numpy(merged.audio_features)
-
             return merged

     def contains_image_inputs(self) -> bool:
...
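The deleted numpy-to-tensor conversion pass is no longer needed because items are expected to carry tensors by the time batches merge. Per the visible context, the surrounding fold is roughly this shape (with a toy stand-in for the real classes):

    def merge_mm_inputs(mm_inputs):
        # Fold all non-None per-request inputs into the first one.
        valid_inputs = [mm for mm in mm_inputs if mm is not None]
        if not valid_inputs:
            return None
        merged = valid_inputs[0]
        for mm_input in valid_inputs[1:]:
            merged.merge(mm_input)
        return merged

    class _Toy:
        def __init__(self, items):
            self.mm_items = list(items)

        def merge(self, other):
            self.mm_items += other.mm_items

    batch = [None, _Toy(["a"]), None, _Toy(["b", "c"])]
    assert merge_mm_inputs(batch).mm_items == ["a", "b", "c"]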
python/sglang/srt/model_executor/model_runner.py  View file @ 5cb552b1
...
@@ -251,16 +251,15 @@ class ModelRunner:
             self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

         if self.is_multimodal:
-            self.mem_fraction_static *= 0.95
+            self.mem_fraction_static *= 0.90
             logger.info(
                 f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
                 f"because this is a multimodal model."
             )
-            if self.model_config.hf_config.architectures == [
-                "MllamaForConditionalGeneration"
-            ]:
-                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-                server_args.chunked_prefill_size = -1
+            logger.info(
+                "Automatically turn off --chunked-prefill-size for multimodal model."
+            )
+            server_args.chunked_prefill_size = -1

         if self.model_config.hf_config.architectures == [
...
@@ -269,18 +268,7 @@ class ModelRunner:
             "Qwen2_5_VLForConditionalGeneration"
         ]:
             # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
             logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
+                "Automatically disable radix cache for qwen-vl series."
             )
-            server_args.chunked_prefill_size = -1
-            server_args.disable_radix_cache = True
-
-        if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
-            # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
-            )
-            server_args.chunked_prefill_size = -1
             server_args.disable_radix_cache = True

         if server_args.enable_deepep_moe:
...
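The per-architecture special cases collapse into one policy: every multimodal model gets a lower static memory fraction and chunked prefill disabled. A self-contained sketch of that policy; `apply_multimodal_defaults` is an illustrative name, not an sglang function:

    from types import SimpleNamespace

    def apply_multimodal_defaults(server_args, mem_fraction_static: float) -> float:
        # One uniform rule instead of per-architecture branches.
        server_args.chunked_prefill_size = -1  # chunked prefill off for all multimodal models
        return mem_fraction_static * 0.90      # more headroom than the previous 0.95 factor

    args = SimpleNamespace(chunked_prefill_size=8192)
    new_fraction = apply_multimodal_defaults(args, 0.88)
    assert args.chunked_prefill_size == -1 and round(new_fraction, 3) == 0.792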
python/sglang/srt/models/clip.py  View file @ 5cb552b1
...
@@ -17,7 +17,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import MultimodalInputs
 from sglang.srt.model_executor.model_runner import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import add_prefix, flatten_nested_list


 class CLIPVisionEmbeddings(nn.Module):
...
@@ -368,7 +368,6 @@ class CLIPVisionTransformer(nn.Module):
         self,
         pixel_values: torch.Tensor,
     ) -> torch.Tensor:
         hidden_states = self.embeddings(pixel_values.to(self.device))
         hidden_states = self.pre_layrnorm(hidden_states)
...
@@ -456,12 +455,18 @@ class CLIPModel(nn.Module):
         get_embedding: bool = True,
     ):
         assert get_embedding, "CLIPEmbeddingModel is only used for embedding"
-        image_inputs = None
+        mm_inputs = []
         if forward_batch.mm_inputs is not None:
-            image_inputs = forward_batch.mm_inputs
+            mm_inputs = forward_batch.mm_inputs

-        if image_inputs is not None and image_inputs[0] is not None:
-            vision_outputs = self.vision_model(image_inputs[0].pixel_values)
+        pixel_values_list = [
+            item.pixel_values
+            for item in flatten_nested_list(
+                [mm_input.mm_items for mm_input in mm_inputs if mm_input is not None]
+            )
+        ]
+        if len(pixel_values_list) != 0:
+            pixel_values = torch.concat(pixel_values_list)
+            vision_outputs = self.vision_model(pixel_values)
             pooled_output = vision_outputs[:, 0, :]
             image_embeds = self.visual_projection(pooled_output)
             image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1)
...
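`flatten_nested_list` lets the CLIP embedding path gather `pixel_values` from every item of every request in one pass. A minimal sketch consistent with how it is used here (the real utility lives in `sglang.srt.utils`):

    def flatten_nested_list(nested):
        # Recursively flattens arbitrarily nested lists into one flat list.
        if isinstance(nested, list):
            return [leaf for sub in nested for leaf in flatten_nested_list(sub)]
        return [nested]

    # A batch of two requests holding two items and one item respectively:
    batch = [["img_a", "img_b"], ["img_c"]]
    assert flatten_nested_list(batch) == ["img_a", "img_b", "img_c"]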
python/sglang/srt/models/deepseek_janus_pro.py  View file @ 5cb552b1
...
@@ -51,7 +51,7 @@ from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
     general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
...
@@ -1959,8 +1959,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         )
         self.logits_processor = LogitsProcessor(config)

-    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
-        pixel_values = image_input.pixel_values
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        pixel_values = torch.concat([item.pixel_values for item in items], dim=0)
         bs, n = pixel_values.shape[0:2]
         pixel_values = pixel_values.to(
             device=self.vision_model.device, dtype=self.vision_model.dtype
...
@@ -1976,7 +1976,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         return images_embeds

     def get_input_embeddings(self) -> nn.Embedding:
-        return self.language_model.model.embed_tokens
+        return self.language_model.get_input_embeddings()

     @torch.no_grad()
     def forward(
...
@@ -1984,23 +1984,18 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
         forward_batch: ForwardBatch,
-        get_embedding: bool = False,
     ) -> torch.Tensor:
-        hidden_states = general_mm_embed_routine(
-            input_ids=input_ids,
-            forward_batch=forward_batch,
-            embed_tokens=self.get_input_embeddings(),
-            image_data_embedding_func=self.get_image_feature,
-            language_model=self.language_model,
-            positions=positions,
-        )
-        return hidden_states
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            mm_data_embedding_func=self.get_image_feature,
+        )
+
+        return self.language_model(
+            input_ids=None,
+            positions=positions,
+            forward_batch=forward_batch,
+            input_embeds=inputs_embeds,
+            get_embedding=False,
+        )

     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
         return self.gen_aligner(self.gen_embed(image_ids))
...
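The forward path now computes multimodal embeddings first and hands them to the language model with `input_ids=None`. A toy sketch of the scatter step behind `general_mm_embed_routine`; the names and the placeholder-token convention here are illustrative, not the sglang API:

    import torch
    import torch.nn as nn

    vocab, dim, image_token_id = 100, 16, 42
    embed_tokens = nn.Embedding(vocab, dim)

    @torch.no_grad()
    def toy_embed_routine(input_ids: torch.Tensor, image_embeds: torch.Tensor) -> torch.Tensor:
        # Embed text tokens, then overwrite the placeholder positions with
        # precomputed image embeddings before the language-model forward.
        inputs_embeds = embed_tokens(input_ids)
        mask = input_ids == image_token_id
        inputs_embeds[mask] = image_embeds.to(inputs_embeds.dtype)
        return inputs_embeds

    input_ids = torch.tensor([1, image_token_id, 7])
    inputs_embeds = toy_embed_routine(input_ids, torch.zeros(1, dim))
    assert inputs_embeds.shape == (3, dim)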