sglang / Commits / 11577ced

Commit 11577ced (unverified)
refactor: bug fixes and refactor for vlm (#4661)
Authored Mar 23, 2025 by Mick; committed via GitHub on Mar 22, 2025.
Parent: ca75741e

Showing 20 changed files with 536 additions and 117 deletions (+536, -117):
benchmark/mmmu/bench_sglang.py                                        +8   -4
python/sglang/srt/configs/janus_pro.py                                +3   -4
python/sglang/srt/configs/model_config.py                             +1   -1
python/sglang/srt/configs/utils.py                                    +25  -0
python/sglang/srt/conversation.py                                     +1   -1
python/sglang/srt/layers/attention/vision.py                          +7   -2
python/sglang/srt/managers/image_processor.py                         +15  -9
python/sglang/srt/managers/image_processors/base_image_processor.py   +24  -8
python/sglang/srt/managers/image_processors/deepseek_vl_v2.py         +15  -20
python/sglang/srt/managers/image_processors/gemma3.py                 +4   -21
python/sglang/srt/managers/image_processors/janus_pro.py              +2   -3
python/sglang/srt/managers/image_processors/llava.py                  +2   -7
python/sglang/srt/managers/image_processors/minicpmv.py               +45  -5
python/sglang/srt/managers/image_processors/mlama.py                  +2   -3
python/sglang/srt/managers/image_processors/qwen_vl.py                +13  -10
python/sglang/srt/managers/mm_utils.py                                +303 -0
python/sglang/srt/managers/schedule_batch.py                          +53  -16
python/sglang/srt/managers/tokenizer_manager.py                       +1   -1
python/sglang/srt/model_executor/forward_batch_info.py                +11  -1
python/sglang/srt/model_executor/model_runner.py                      +1   -1
--- a/benchmark/mmmu/bench_sglang.py
+++ b/benchmark/mmmu/bench_sglang.py
 """
 Bench the sglang-hosted vLM with benchmark MMMU

 Usage:
     python benchmark/mmmu/bench_sglang.py --model-path Qwen/Qwen2-VL-7B-Instruct --chat-template qwen2-vl

 The eval output will be logged
 """
 import argparse
+import time

 import openai
 from data_utils import save_json
 ...
@@ -37,6 +38,7 @@ def eval_mmmu(args):
     # had to use an openai server, since SglImage doesn't support image data
     client = openai.Client(api_key="sk", base_url=f"http://127.0.0.1:{args.port}/v1")
+    start = time.time()
     for i, sample in enumerate(tqdm(samples)):
         prompt = sample["final_input_prompt"]
         prefix = prompt.split("<")[0]
 ...
@@ -73,6 +75,8 @@ def eval_mmmu(args):
         response = response.choices[0].message.content
         process_result(response, sample, answer_dict, out_samples)

+    print(f"Benchmark time: {time.time() - start}")
     args.output_path = f"./val_sglang.json"
     save_json(args.output_path, out_samples)
     eval_result(model_answer_path=args.output_path, answer_dict=answer_dict)
--- a/python/sglang/srt/configs/janus_pro.py
+++ b/python/sglang/srt/configs/janus_pro.py
@@ -9,8 +9,6 @@ import PIL
 import torch
 from PIL.Image import Image
 from transformers import (
-    AutoImageProcessor,
-    AutoProcessor,
     BaseImageProcessor,
     BatchFeature,
     LlamaConfig,
 ...
@@ -20,6 +18,7 @@ from transformers import (
 )
 from transformers.image_utils import to_numpy_array

+from sglang.srt.configs.utils import register_image_processor, register_processor
 from sglang.srt.mm_utils import expand2square
 ...
@@ -625,5 +624,5 @@ class VLMImageProcessorConfig(PretrainedConfig):
         super().__init__(**kwargs)

-AutoProcessor.register(MultiModalityConfig, VLChatProcessor, exist_ok=True)
-AutoImageProcessor.register(VLMImageProcessorConfig, None, VLMImageProcessor, None)
+register_processor(MultiModalityConfig, VLChatProcessor)
+register_image_processor(MultiModalityConfig, VLMImageProcessor)
--- a/python/sglang/srt/configs/model_config.py
+++ b/python/sglang/srt/configs/model_config.py
@@ -460,6 +460,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
 multimodal_model_archs = [
+    "DeepseekVL2ForCausalLM",
     "LlavaLlamaForCausalLM",
     "LlavaQwenForCausalLM",
     "LlavaMistralForCausalLM",
 ...
@@ -472,7 +473,6 @@ multimodal_model_archs = [
     "Qwen2_5_VLForConditionalGeneration",
     "MiniCPMV",
     "MultiModalityCausalLM",
-    "DeepseekVL2ForCausalLM",
 ]
--- /dev/null
+++ b/python/sglang/srt/configs/utils.py (new file, mode 100644)
from typing import Type

from transformers import (
    AutoImageProcessor,
    AutoProcessor,
    BaseImageProcessor,
    PretrainedConfig,
    ProcessorMixin,
)


def register_image_processor(
    config: Type[PretrainedConfig], image_processor: Type[BaseImageProcessor]
):
    """
    register customized hf image processor while removing hf impl
    """
    AutoImageProcessor.register(config, None, image_processor, None, exist_ok=True)


def register_processor(config: Type[PretrainedConfig], processor: Type[ProcessorMixin]):
    """
    register customized hf processor while removing hf impl
    """
    AutoProcessor.register(config, processor, exist_ok=True)
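These two helpers wrap `AutoProcessor` / `AutoImageProcessor` registration so a custom processor can replace a HF-provided one without raising on re-registration. A minimal usage sketch, assuming a recent transformers version; the `MyVLM*` classes are hypothetical placeholders, not part of this commit:

```python
# Hypothetical example: register a custom processor pair for a new config.
from transformers import BaseImageProcessor, PretrainedConfig, ProcessorMixin

from sglang.srt.configs.utils import register_image_processor, register_processor


class MyVLMConfig(PretrainedConfig):  # placeholder config class
    model_type = "my-vlm"


class MyVLMImageProcessor(BaseImageProcessor):  # placeholder image processor
    pass


class MyVLMProcessor(ProcessorMixin):  # placeholder processor
    attributes = []


# exist_ok=True inside the helpers makes repeated registration (e.g. on
# module re-import) a no-op instead of a ValueError.
register_processor(MyVLMConfig, MyVLMProcessor)
register_image_processor(MyVLMConfig, MyVLMImageProcessor)
```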
--- a/python/sglang/srt/conversation.py
+++ b/python/sglang/srt/conversation.py
@@ -653,7 +653,7 @@ register_conv_template(
     Conversation(
         name="gemma-it",
         system_message="You are a helpful assistant.",
-        system_template="<bos><start_of_turn>user{system_message}\n\n",
+        system_template="<start_of_turn>user{system_message}\n\n",
         roles=("<start_of_turn>user\n", "<start_of_turn>model\n"),
         sep="<end_of_turn>\n",
         sep_style=SeparatorStyle.GEMMA3,
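The fix drops the stray `<bos>` prefix from the gemma-it system template (presumably because the BOS token is added during tokenization). A quick sanity check of the corrected string:

```python
# The corrected template renders without a leading <bos>.
system_template = "<start_of_turn>user{system_message}\n\n"
rendered = system_template.format(system_message="You are a helpful assistant.")
assert rendered.startswith("<start_of_turn>user")
```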
--- a/python/sglang/srt/layers/attention/vision.py
+++ b/python/sglang/srt/layers/attention/vision.py
@@ -143,9 +143,14 @@ class VisionAttention(nn.Module):
         if position_embeddings is not None:
             cos, sin = position_embeddings
             original_shape = q.shape
-            q, k = q.view(s, head, -1), k.view(s, head, -1)
+            # [total_tokens, head, head_size]
+            q = q.view(-1, head, self.head_size)
+            k = k.view(-1, head, self.head_size)
             q, k = apply_rotary_pos_emb(q, k, cos, sin)
-            q, k = q.reshape(original_shape), k.reshape(original_shape)
+            q = q.view(original_shape)
+            k = k.view(original_shape)
         if self.use_qkv_parallel:
             pass
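The reshape fix normalizes `q`/`k` to `[total_tokens, head, head_size]` before applying rotary embeddings, then restores the original shape. A small, self-contained shape round-trip under assumed toy dimensions:

```python
# Shape round-trip sketch for the reshape fix above: collapse the input to
# [total_tokens, head, head_size] for RoPE, then restore it exactly.
import torch

s, b, head, head_size = 4, 2, 8, 16
q = torch.randn(s, b, head * head_size)
original_shape = q.shape

q_flat = q.view(-1, head, head_size)      # [total_tokens, head, head_size]
# ... apply_rotary_pos_emb(q_flat, k_flat, cos, sin) would run here ...
q_restored = q_flat.view(original_shape)  # exact inverse, shares storage

assert q_restored.shape == original_shape
assert torch.equal(q, q_restored)
```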
--- a/python/sglang/srt/managers/image_processor.py
+++ b/python/sglang/srt/managers/image_processor.py
 # TODO: also move pad_input_ids into this module
 import importlib
 import inspect
 import logging
 import pkgutil
 from functools import lru_cache
 from typing import Union

 from torch import Tensor
 from transformers import IMAGE_PROCESSOR_MAPPING

 from sglang.srt.managers.image_processors.base_image_processor import (
 ...
@@ -18,9 +21,7 @@ logger = logging.getLogger(__name__)
 IMAGE_PROCESSOR_MAPPING = {}

-def get_image_processor(hf_config, server_args: ServerArgs, processor) -> BaseImageProcessor:
+def get_image_processor(hf_config, server_args, processor) -> BaseImageProcessor:
     for model_cls, processor_cls in IMAGE_PROCESSOR_MAPPING.items():
         if model_cls.__name__ in hf_config.architectures:
             return processor_cls(hf_config, server_args, processor)
 ...
@@ -42,13 +43,18 @@ def import_image_processors():
     try:
         module = importlib.import_module(name)
     except Exception as e:
         logger.warning(f"Ignore import error when loading {name}: " f"{e}")
         continue
-    if hasattr(module, "ImageProcessorMapping"):
-        entry = module.ImageProcessorMapping
-        if isinstance(entry, dict):
-            for processor_name, cls in entry.items():
-                IMAGE_PROCESSOR_MAPPING[processor_name] = cls
+    all_members = inspect.getmembers(module, inspect.isclass)
+    classes = [
+        member for name, member in all_members if member.__module__ == module.__name__
+    ]
+    for cls in classes:
+        if issubclass(cls, BaseImageProcessor):
+            for arch in getattr(cls, "models"):
+                IMAGE_PROCESSOR_MAPPING[arch] = cls
     # also register processors
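With this refactor, an image processor advertises the model architectures it serves through a `models` class attribute, and `import_image_processors()` builds the registry by scanning each module for `BaseImageProcessor` subclasses. A minimal sketch of that discovery pattern; all classes here are illustrative stand-ins, not sglang types:

```python
# Minimal sketch of the inspect-based discovery pattern, run on a toy module.
import inspect
import types


class BaseImageProcessor:  # stand-in for the sglang base class
    models = []


class FakeModel:  # stand-in for a model architecture class
    pass


class FakeProcessor(BaseImageProcessor):
    models = [FakeModel]


module = types.ModuleType("fake_processors")
module.FakeProcessor = FakeProcessor
FakeProcessor.__module__ = module.__name__

mapping = {}
for _, member in inspect.getmembers(module, inspect.isclass):
    if member.__module__ == module.__name__ and issubclass(member, BaseImageProcessor):
        for arch in getattr(member, "models"):
            mapping[arch] = member

assert mapping == {FakeModel: FakeProcessor}
```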
--- a/python/sglang/srt/managers/image_processors/base_image_processor.py
+++ b/python/sglang/srt/managers/image_processors/base_image_processor.py
@@ -4,14 +4,14 @@ import dataclasses
 import multiprocessing as mp
 import os
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import Optional, Union

 import PIL
 import transformers
 from decord import VideoReader, cpu
+from openai import BadRequestError
 from PIL import Image

 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import load_image
 from sglang.utils import logger
 ...
@@ -31,8 +31,16 @@ class BaseImageProcessorOutput:
     # input_text, with each frame of video/image represented as an image_token
     input_text: str

+    def normalize(self):
+        for field_name in ["data_hashes", "image_sizes", "all_frames"]:
+            field = getattr(self, field_name, None)
+            if field is not None and isinstance(field, list) and len(field) == 0:
+                setattr(self, field_name, None)


 class BaseImageProcessor(ABC):
+    models = []
+
     def __init__(self, hf_config, server_args, _processor):
         self.hf_config = hf_config
         self._processor = _processor
 ...
@@ -40,6 +48,9 @@ class BaseImageProcessor(ABC):
         # FIXME: not accurate, model and image specific
         self.NUM_TOKEN_PER_FRAME = 330

+        # Initialize global processor first
+        init_global_processor(self, server_args)
+
         self.executor = concurrent.futures.ProcessPoolExecutor(
             initializer=init_global_processor,
             mp_context=mp.get_context("fork"),
 ...
@@ -113,7 +124,7 @@ class BaseImageProcessor(ABC):
         self,
         input_ids: list[int],
         image_data,
-        image_token: str,
+        image_token: Union[int, str],
         max_req_input_len: int,
         return_text: Optional[bool] = True,
         discard_alpha_channel: bool = True,
 ...
@@ -122,9 +133,16 @@ class BaseImageProcessor(ABC):
         Each frame of video/image will be replaced by a single image token

         Args:
+            image_token: The token ID representing the image placeholder.
             discard_alpha_channel: if True, discards the alpha channel in the returned images
         """
+        if isinstance(image_token, int):
+            image_token_str = self._processor.tokenizer.convert_ids_to_tokens(image_token)
+        else:
+            image_token_str = image_token
         if isinstance(input_ids, list) and return_text:
             assert len(input_ids) and isinstance(input_ids[0], int)
 ...
@@ -190,13 +208,11 @@ class BaseImageProcessor(ABC):
                 new_text += text_part
             except Exception as e:
-                import openai
                 logger.error(f"An exception occurred while loading images: {e}")
-                continue
+                raise BadRequestError(f"An exception occurred while loading images: {e}")

-        return BaseImageProcessorOutput(
+        out = BaseImageProcessorOutput(
             image_hashes=hashes,
 ...
@@ -204,6 +220,8 @@ class BaseImageProcessor(ABC):
             all_frames=images,
             input_text=new_text,
         )
+        out.normalize()
+        return out


 class DummyImageProcessor(BaseImageProcessor):
 ...
@@ -214,9 +232,7 @@ class DummyImageProcessor(BaseImageProcessor):
         return None

-def init_global_processor(sglang_image_processor: BaseImageProcessor, server_args: ServerArgs):
+def init_global_processor(sglang_image_processor: BaseImageProcessor, server_args):
     """Init the global processor for multi-modal models."""
     global global_processor
     transformers.logging.set_verbosity_error()
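The new `normalize()` hook collapses empty-list fields to `None`, so later `is not None` checks keep their meaning. A tiny stand-alone illustration with a simplified stand-in for `BaseImageProcessorOutput`:

```python
# Tiny illustration of the normalize() behavior: empty-list fields collapse
# to None while populated fields are left untouched.
import dataclasses
from typing import Optional


@dataclasses.dataclass
class Output:  # simplified stand-in for BaseImageProcessorOutput
    image_sizes: Optional[list] = None
    all_frames: Optional[list] = None

    def normalize(self):
        for field_name in ["image_sizes", "all_frames"]:
            field = getattr(self, field_name, None)
            if field is not None and isinstance(field, list) and len(field) == 0:
                setattr(self, field_name, None)


out = Output(image_sizes=[], all_frames=[(224, 224)])
out.normalize()
assert out.image_sizes is None and out.all_frames == [(224, 224)]
```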
--- a/python/sglang/srt/managers/image_processors/deepseek_vl_v2.py
+++ b/python/sglang/srt/managers/image_processors/deepseek_vl_v2.py
@@ -16,13 +16,9 @@
 # COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
 # IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 # CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 import asyncio
 import math
 from typing import List, Union

 import torch
 from PIL import Image, ImageOps

 from sglang.srt.managers.image_processor import BaseImageProcessor
 from sglang.srt.managers.image_processors.base_image_processor import (
 ...
@@ -32,18 +28,24 @@ from sglang.srt.models.deepseek_vl2 import DeepseekVL2ForCausalLM
 class DeepseekVL2ImageProcessor(BaseImageProcessor):
+    models = [DeepseekVL2ForCausalLM]
+
     def __init__(self, hf_config, server_args, _processor):
-        # with contextlib.suppress(ValueError):
-        #     AutoProcessor.register("DeepseekVLV2Processor", DeepseekVLV2Processor)
         super().__init__(hf_config, server_args, _processor)
         self.IMAGE_TOKEN = "<image>"

     @staticmethod
     def _process_images_task(image, input_text, max_req_input_len):
-        return get_global_processor().__call__(
+        processor = get_global_processor()
+        res = processor.__call__(
             conversations=input_text, images=image, max_req_input_len=max_req_input_len
         )
+        image_token_id = processor.image_token_id
+        res["im_token_id"] = image_token_id
+        return res

     async def _process_images(self, image_data, input_text, max_req_input_len):
         if self.executor is not None:
             loop = asyncio.get_event_loop()
 ...
@@ -70,18 +72,15 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
         if not isinstance(image_data, list):
             image_data = [image_data]

-        images, image_hashes, image_sizes = [], [], []
+        images, image_sizes = [], []
         image_token = self.IMAGE_TOKEN
         base_output = self.load_images(input_ids, image_data, image_token, max_req_input_len)
+        base_output.all_frames = [img.convert("RGB") for img in base_output.all_frames]
         res = await self._process_images(base_output.all_frames, base_output.input_text, max_req_input_len)
-        pixel_values = res["images"]
-        input_ids = res["input_ids"]
         images_seq_mask = res["images_seq_mask"]
         images_spatial_crop = res["images_spatial_crop"]
         batched_images_spatial_crop = []
 ...
@@ -89,16 +88,12 @@ class DeepseekVL2ImageProcessor(BaseImageProcessor):
         batched_images_spatial_crop = torch.stack(batched_images_spatial_crop, dim=0)

         return {
-            "input_ids": input_ids.tolist(),
-            "pixel_values": pixel_values,
-            "image_hashes": image_hashes,
+            "input_ids": res["input_ids"].tolist(),
+            "pixel_values": res["images"],
+            "im_token_id": res["im_token_id"],
+            "image_hashes": base_output.image_hashes,
             "image_sizes": image_sizes,
-            "image_seq_mask": images_seq_mask,
+            "images_emb_mask": images_seq_mask,
             "image_spatial_crop": batched_images_spatial_crop,
             "modalities": request_obj.modalities or ["image"],
         }

-ImageProcessorMapping = {
-    DeepseekVL2ForCausalLM: DeepseekVL2ImageProcessor,
-}
--- a/python/sglang/srt/managers/image_processors/gemma3.py
+++ b/python/sglang/srt/managers/image_processors/gemma3.py
@@ -17,14 +17,15 @@ logger = logging.get_logger(__name__)
 class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
+    models = [Gemma3ForConditionalGeneration]
+
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
         self.IMAGE_TOKEN = "<start_of_image>"
         self.IM_START_TOKEN_ID = hf_config.boi_token_index
         self.IM_END_TOKEN_ID = hf_config.eoi_token_index

-    @staticmethod
-    def _process_images_task(images, input_text, _hf_config):
+    async def _process_single_image(self, images, input_text) -> dict:
         if isinstance(images, list) and len(images) == 0:
             images = None
         processor = get_global_processor()
 ...
@@ -46,19 +47,6 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
             "pixel_values": pixel_values,
         }

-    async def _process_images(self, images, input_text) -> dict:
-        if self.executor is not None:
-            loop = asyncio.get_event_loop()
-            return await loop.run_in_executor(
-                self.executor,
-                Gemma3SGLangImageProcessor._process_images_task,
-                images,
-                input_text,
-                self.hf_config,
-            )
-        else:
-            return self._process_images_task(images, input_text, self.hf_config)

     async def process_images_async(
         self,
         image_data: List[Union[str, bytes]],
 ...
@@ -82,7 +70,7 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
             discard_alpha_channel=True,
         )

-        ret = await self._process_images(
+        ret = await self._process_single_image(
             input_text=base_output.input_text, images=base_output.all_frames
         )
 ...
@@ -93,8 +81,3 @@ class Gemma3SGLangImageProcessor(SGLangBaseImageProcessor):
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
         }

-ImageProcessorMapping = {
-    Gemma3ForConditionalGeneration: Gemma3SGLangImageProcessor,
-}
--- a/python/sglang/srt/managers/image_processors/janus_pro.py
+++ b/python/sglang/srt/managers/image_processors/janus_pro.py
@@ -11,6 +11,8 @@ from sglang.srt.models.deepseek_janus_pro import MultiModalityCausalLM
 class JanusProProcessor(SGLangBaseImageProcessor):
+    models = [MultiModalityCausalLM]
+
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
 ...
@@ -77,6 +79,3 @@ class JanusProProcessor(SGLangBaseImageProcessor):
             "im_end_id": res["im_end_id"],
             "im_token_id": res["im_token_id"],
         }

-ImageProcessorMapping = {MultiModalityCausalLM: JanusProProcessor}
--- a/python/sglang/srt/managers/image_processors/llava.py
+++ b/python/sglang/srt/managers/image_processors/llava.py
@@ -15,6 +15,8 @@ from sglang.utils import get_exception_traceback
 class LlavaImageProcessor(BaseImageProcessor):
+    models = [LlavaVidForCausalLM, LlavaQwenForCausalLM, LlavaMistralForCausalLM]
+
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
 ...
@@ -143,10 +145,3 @@ class LlavaImageProcessor(BaseImageProcessor):
             "image_sizes": image_sizes,
             "modalities": request_obj.modalities or ["image"],
         }

-ImageProcessorMapping = {
-    LlavaVidForCausalLM: LlavaImageProcessor,
-    LlavaQwenForCausalLM: LlavaImageProcessor,
-    LlavaMistralForCausalLM: LlavaImageProcessor,
-}
--- a/python/sglang/srt/managers/image_processors/minicpmv.py
+++ b/python/sglang/srt/managers/image_processors/minicpmv.py
 import asyncio
 from typing import List, Union

+import torch

 from sglang.srt.managers.image_processor import BaseImageProcessor
 from sglang.srt.managers.image_processors.base_image_processor import (
     get_global_processor,
 ...
@@ -9,6 +11,8 @@ from sglang.srt.models.minicpmv import MiniCPMV
 class MiniCPMVImageProcessor(BaseImageProcessor):
+    models = [MiniCPMV]
+
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
         self.IMAGE_TOKEN = "(<image>./</image>)"
 ...
@@ -69,21 +73,57 @@ class MiniCPMVImageProcessor(BaseImageProcessor):
         # Collect special token ids
         tokenizer = self._processor.tokenizer
         im_start_id = tokenizer.im_start_id
         im_token_id = tokenizer.unk_token_id
         im_end_id = tokenizer.im_end_id
         if tokenizer.slice_start_id:
             slice_start_id = tokenizer.slice_start_id
             slice_end_id = tokenizer.slice_end_id

+        pixel_values = res["pixel_values"]
+        tgt_sizes = res["tgt_sizes"]
+
+        if not isinstance(pixel_values, (torch.Tensor, list)):
+            raise ValueError(
+                "Incorrect type of pixel values. " f"Got type: {type(pixel_values)}"
+            )
+        if not isinstance(tgt_sizes, (torch.Tensor, list)):
+            raise ValueError(
+                "Incorrect type of target sizes. " f"Got type: {type(tgt_sizes)}"
+            )
+        if len(pixel_values) != len(tgt_sizes):
+            raise ValueError(
+                "Inconsistent batch lengths, found: "
+                f"{len(pixel_values)} vs. {len(tgt_sizes)}"
+            )
+
+        # tgt_sizes = [tgt_size for tgt_size in tgt_sizes if isinstance(tgt_size, torch.Tensor)]
+        # tgt_sizes = torch.vstack(tgt_sizes).type(torch.int32)
+        pixel_values_flat: List[torch.Tensor] = []
+        tgt_sizes_flat: List[torch.Tensor] = []
+        for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):
+            # per image
+            if len(pixel_b) != len(tgt_b):
+                raise ValueError(
+                    "Inconsistent N lengths, found: " f"{len(pixel_b)} vs {len(tgt_b)}"
+                )
+            for pixel_n, tgt_n in zip(pixel_b, tgt_b):
+                # per patch
+                pixel_values_flat += [pixel_n]
+                tgt_sizes_flat += [tgt_n]
+
+        pixel_values = pixel_values_flat
+        tgt_sizes = torch.stack(tgt_sizes_flat)

         return {
             "input_ids": res["input_ids"].flatten().tolist(),
-            "pixel_values": res["pixel_values"],
-            "tgt_sizes": res["tgt_sizes"],
+            "pixel_values": pixel_values,
+            "tgt_sizes": tgt_sizes,
             "image_hashes": base_output.image_hashes,
             "modalities": request_obj.modalities or ["image"],
             "im_start_id": im_start_id,
             "im_token_id": im_token_id,
             "im_end_id": im_end_id,
             "slice_start_id": slice_start_id,
             "slice_end_id": slice_end_id,
         }

-ImageProcessorMapping = {MiniCPMV: MiniCPMVImageProcessor}
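The new MiniCPM-V path validates that pixel values and target sizes line up, then flattens the per-image/per-patch nesting before stacking `tgt_sizes`. A toy run of that flattening logic with made-up tensor shapes:

```python
# Toy run of the per-image/per-patch flattening above: nested pixel values and
# target sizes are flattened in lockstep, then tgt_sizes is stacked.
import torch

pixel_values = [[torch.zeros(3, 2, 2)], [torch.ones(3, 2, 2), torch.ones(3, 2, 2)]]
tgt_sizes = [[torch.tensor([1, 1])], [torch.tensor([2, 2]), torch.tensor([2, 2])]]

pixel_values_flat, tgt_sizes_flat = [], []
for pixel_b, tgt_b in zip(pixel_values, tgt_sizes):  # per image
    for pixel_n, tgt_n in zip(pixel_b, tgt_b):  # per patch
        pixel_values_flat.append(pixel_n)
        tgt_sizes_flat.append(tgt_n)

assert len(pixel_values_flat) == 3                    # one entry per patch
assert torch.stack(tgt_sizes_flat).shape == (3, 2)    # [num_patches, 2]
```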
--- a/python/sglang/srt/managers/image_processors/mlama.py
+++ b/python/sglang/srt/managers/image_processors/mlama.py
@@ -10,6 +10,8 @@ from sglang.srt.utils import load_image
 class MllamaImageProcessor(BaseImageProcessor):
+    models = [MllamaForConditionalGeneration]
+
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
 ...
@@ -55,6 +57,3 @@ class MllamaImageProcessor(BaseImageProcessor):
         image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
         return image_inputs

-ImageProcessorMapping = {MllamaForConditionalGeneration: MllamaImageProcessor}
--- a/python/sglang/srt/managers/image_processors/qwen_vl.py
+++ b/python/sglang/srt/managers/image_processors/qwen_vl.py
@@ -2,6 +2,7 @@ import asyncio
 import math
 from typing import List, Union

+import torch
 from PIL import Image

 from sglang.srt.managers.image_processor import BaseImageProcessor
 ...
@@ -14,6 +15,8 @@ from sglang.srt.models.qwen2_vl import Qwen2VLForConditionalGeneration
 # Compatible with Qwen2VL and Qwen2_5VL
 class Qwen2_5VLImageProcessor(BaseImageProcessor):
+    models = [Qwen2VLForConditionalGeneration, Qwen2_5_VLForConditionalGeneration]
+
     def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
         self.IMAGE_TOKEN = "<|vision_start|><|image_pad|><|vision_end|>"
 ...
@@ -43,7 +46,7 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
             "video_grid_thws": getattr(result, "video_grid_thws", None),
         }

-    async def _process_images(self, images, input_text) -> dict:
+    async def _process_single_image(self, images, input_text) -> dict:
         if self.executor is not None:
             loop = asyncio.get_event_loop()
             return await loop.run_in_executor(
 ...
@@ -138,23 +141,23 @@ class Qwen2_5VLImageProcessor(BaseImageProcessor):
         images = [resize_image(image) for image in base_output.all_frames]

-        ret = await self._process_images(images, base_output.input_text)
+        ret = await self._process_single_image(
+            images=images, input_text=base_output.input_text
+        )
+
+        image_grid_thws = torch.concat([ret["image_grid_thw"]])
+        video_grid_thws = None
         return {
             "input_ids": ret["input_ids"].flatten().tolist(),
             "pixel_values": ret["pixel_values"],
             "image_hashes": base_output.image_hashes,
             "modalities": request_obj.modalities or ["image"],
-            "image_grid_thws": ret["image_grid_thw"],
-            "video_grid_thws": ret["video_grid_thws"],
+            "image_grid_thws": image_grid_thws,
+            "video_grid_thws": video_grid_thws,
             "im_start_id": self.IM_START_TOKEN_ID,
             "im_end_id": self.IM_END_TOKEN_ID,
             "im_token_id": self.image_token_id,
             "video_token_id": self.video_token_id,
             "second_per_grid_ts": ret["second_per_grid_ts"],
         }

-ImageProcessorMapping = {
-    Qwen2VLForConditionalGeneration: Qwen2_5VLImageProcessor,
-    Qwen2_5_VLForConditionalGeneration: Qwen2_5VLImageProcessor,
-}
--- a/python/sglang/srt/managers/multi_modality_padding.py (renamed)
+++ b/python/sglang/srt/managers/mm_utils.py
"""
Multimodality utils
"""
from
abc
import
abstractmethod
from
typing
import
Callable
,
List
,
Optional
,
Tuple
from
sglang.srt.managers.schedule_batch
import
ImageInputs
import
torch
from
torch
import
nn
from
sglang.srt.managers.schedule_batch
import
(
ImageInputs
,
global_server_args_dict
,
logger
,
)
from
sglang.srt.model_executor.forward_batch_info
import
ForwardBatch
from
sglang.utils
import
logger
...
...
@@ -115,7 +127,6 @@ class MultModalityDataPaddingPatternSingleToken(MultiModalityDataPaddingPattern)
         input_ids_with_image = []
         for image_cnt, _ in enumerate(image_grid_thws):
-            print(f"image_cnt {image_cnt}")
             num_image_tokens = self.num_data_token_calc_func(image_grid_thws[image_cnt])
             if image_cnt == 0:
                 non_image_tokens = input_ids[: image_indices[image_cnt]]
 ...
@@ -132,3 +143,161 @@ class MultModalityDataPaddingPatternSingleToken(MultModalityDataPaddingPattern)
         input_ids_with_image.extend(input_ids[image_indices[-1] + 1 :])
         return input_ids_with_image

+class MultiModalityDataPaddingPatternImageTokens(MultiModalityDataPaddingPattern):
+    """In this pattern, data tokens should be represented as image tokens (e.g. <image><image>....<image>)"""
+
+    def __init__(self, image_token_id: torch.Tensor) -> None:
+        self.image_token_id = image_token_id
+
+    def pad_input_tokens(self, input_ids: List[int], image_inputs) -> List[int]:
+        """
+        This function will replace the data-tokens in between with pad_values accordingly
+        """
+        pad_values = image_inputs.pad_values
+        assert len(pad_values) != 0
+
+        input_ids_tensor = torch.tensor(input_ids)
+        mask = torch.isin(input_ids_tensor, self.image_token_id)
+
+        num_image_tokens = mask.sum().item()
+        repeated_pad_values = torch.tensor(pad_values).repeat(
+            num_image_tokens // len(pad_values) + 1
+        )[:num_image_tokens]
+
+        input_ids_tensor[mask] = repeated_pad_values
+        return input_ids_tensor.tolist()
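`MultiModalityDataPaddingPatternImageTokens` replaces every image-token id with the request's pad values (derived from image hashes), cycling them to cover all image tokens. A worked example with made-up ids:

```python
# Worked example of the image-token padding pattern: each occurrence of the
# image token id (9999 here) is replaced by cycled pad values. All values
# are illustrative.
import torch

image_token_id = torch.tensor([9999])
pad_values = [101, 202]  # e.g. image_hashes % (1 << 30)

ids = torch.tensor([1, 9999, 9999, 9999, 9999, 2])
mask = torch.isin(ids, image_token_id)

n = int(mask.sum().item())
repeated = torch.tensor(pad_values).repeat(n // len(pad_values) + 1)[:n]
ids[mask] = repeated

assert ids.tolist() == [1, 101, 202, 101, 202, 2]
```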
+def embed_image_inputs(
+    image_input: ImageInputs,
+    input_ids: torch.Tensor,
+    input_embedding: nn.Embedding,
+    image_embedding_func,
+    placeholder_token_ids: List[int] = None,
+) -> Optional[torch.Tensor]:
+    """
+    Calculate the image embeddings if necessary, then scatter the result with
+    the help of a boolean mask denoting the embed locations
+
+    Returns:
+        final embedding: Optional[torch.Tensor]
+    """
+    if image_input is None:
+        return None
+
+    placeholder_token_ids = placeholder_token_ids or image_input.pad_values
+
+    # boolean masking the special tokens
+    special_image_mask = torch.isin(
+        input_ids,
+        torch.tensor(placeholder_token_ids, device=input_ids.device),
+    ).unsqueeze(-1)
+
+    num_image_tokens_in_input_ids = special_image_mask.sum()
+    if num_image_tokens_in_input_ids == 0:
+        # unexpected
+        inputs_embeds = input_embedding(input_ids)
+    else:
+        image_embedding = image_embedding_func(image_input)
+
+        if image_embedding.dim() == 2:
+            num_image_tokens_in_embedding = image_embedding.shape[0]
+        else:
+            num_image_tokens_in_embedding = (
+                image_embedding.shape[0] * image_embedding.shape[1]
+            )
+        if num_image_tokens_in_input_ids != num_image_tokens_in_embedding:
+            num_image = num_image_tokens_in_input_ids // image_embedding.shape[1]
+            image_embedding = image_embedding[:num_image, :]
+            logger.warning(
+                f"Number of images does not match number of special image tokens in the input text. "
+                f"Got {num_image_tokens_in_input_ids} image tokens in the text but {num_image_tokens_in_embedding} "
+                "tokens from image embeddings."
+            )
+            # TODO: chunked prefill will split special tokens from input_ids into several passes, failing the embedding
+            # a fix may be cache the unfinished image embedding for future reuse, determine the tokens to embed with
+            # extend_start_loc and extend_seq_lens
+            if num_image_tokens_in_input_ids > num_image_tokens_in_embedding:
+                chunked_prefill_size = global_server_args_dict["chunked_prefill_size"]
+                if chunked_prefill_size != -1:
+                    logger.warning(
+                        "You may want to avoid this issue by raising `chunked_prefill_size`, or disabling chunked_prefill"
+                    )
+
+        vocab_size = input_embedding.num_embeddings
+        # Important: clamp after getting original image regions
+        # Clamp input ids. This is because the input_ids for the image tokens are
+        # filled with the hash values of the image for the prefix matching in the radix attention.
+        # Their values are useless because their embeddings will be replaced by vision embeddings anyway.
+        input_ids.clamp_(min=0, max=vocab_size - 1)
+        inputs_embeds = input_embedding(input_ids)
+
+        special_image_mask = special_image_mask.expand_as(inputs_embeds).to(
+            inputs_embeds.device
+        )
+        inputs_embeds = inputs_embeds.masked_scatter(
+            special_image_mask,
+            image_embedding.to(inputs_embeds.device, inputs_embeds.dtype),
+        )
+    return inputs_embeds
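The core of `embed_image_inputs` is the `masked_scatter` step: token embeddings at placeholder positions are overwritten element-by-element with the vision embeddings. A toy-dimension sketch of just that step:

```python
# Sketch of the masked-scatter step: rows of the image embedding fill the
# masked (placeholder) positions of the token embeddings. Toy dimensions.
import torch

hidden = 4
inputs_embeds = torch.zeros(6, hidden)
special_image_mask = torch.tensor([False, True, True, False, True, False]).unsqueeze(-1)

image_embedding = torch.ones(3, hidden)  # one row per image token

out = inputs_embeds.masked_scatter(
    special_image_mask.expand_as(inputs_embeds), image_embedding
)
assert out[1].eq(1).all() and out[0].eq(0).all()
```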
+def embed_image_embedding(
+    inputs_embeds: torch.Tensor,
+    image_embedding: torch.Tensor,
+    image_bounds: torch.Tensor,
+) -> torch.Tensor:
+    """
+    scatter image_embedding into inputs_embeds according to image_bounds
+    """
+    if len(image_bounds) > 0:
+        image_indices = torch.stack(
+            [
+                torch.arange(start, end, dtype=torch.long)
+                for start, end in image_bounds.tolist()
+            ]
+        ).to(inputs_embeds.device)
+
+        inputs_embeds.scatter_(
+            0,
+            image_indices.view(-1, 1).repeat(1, inputs_embeds.shape[-1]),
+            image_embedding.view(-1, image_embedding.shape[-1]),
+        )
+    return inputs_embeds
+def general_mm_embed_routine(
+    input_ids: torch.Tensor,
+    positions: torch.Tensor,
+    forward_batch: ForwardBatch,
+    embed_tokens: nn.Embedding,
+    image_embedding_func: Callable[[ImageInputs], torch.Tensor],
+    placeholder_token_ids: List[int] = None,
+):
+    """
+    a general wrapper function to get final input embeds from multimodal models
+    with a language model as causal model
+    """
+    if (
+        forward_batch.forward_mode.is_decode()
+        or not forward_batch.contains_image_inputs()
+    ):
+        inputs_embeds = embed_tokens(input_ids)
+    else:
+        image = forward_batch.merge_image_inputs()
+        inputs_embeds = embed_image_inputs(
+            image_input=image,
+            input_ids=input_ids,
+            input_embedding=embed_tokens,
+            image_embedding_func=image_embedding_func,
+            placeholder_token_ids=placeholder_token_ids,
+        )
+        # once used, image_inputs is useless
+        # just being defensive here
+        forward_batch.image_inputs = None
+    return inputs_embeds
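`general_mm_embed_routine` only takes the vision path on prefill/extend batches that actually carry image inputs; decode steps fall back to plain token embedding. A runnable toy of that gating, with simplified stand-ins for the sglang batch types:

```python
# Runnable toy of the gating logic: decode steps, or batches with no image
# inputs, skip the vision path. All classes are stand-ins, not sglang types.
import torch
from torch import nn


class FakeForwardMode:
    def __init__(self, decode):
        self._decode = decode

    def is_decode(self):
        return self._decode


class FakeBatch:
    def __init__(self, decode, has_images):
        self.forward_mode = FakeForwardMode(decode)
        self._has_images = has_images

    def contains_image_inputs(self):
        return self._has_images


embed_tokens = nn.Embedding(100, 8)
input_ids = torch.tensor([1, 2, 3])

for batch in (FakeBatch(decode=True, has_images=True), FakeBatch(decode=False, has_images=False)):
    if batch.forward_mode.is_decode() or not batch.contains_image_inputs():
        inputs_embeds = embed_tokens(input_ids)  # text-only fast path
    # else: merge image inputs and call embed_image_inputs(...)

assert inputs_embeds.shape == (3, 8)
```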
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -77,6 +77,7 @@ global_server_args_dict = {
     "enable_flashmla": ServerArgs.enable_flashmla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+    "chunked_prefill_size": ServerArgs.chunked_prefill_size,
 }

 logger = logging.getLogger(__name__)
 ...
@@ -160,7 +161,8 @@ class ImageInputs:
     aspect_ratio_mask: Optional[List[torch.Tensor]] = None

     # QWen2-VL related
-    image_grid_thws: List[Tuple[int, int, int]] = None
+    # [num_of_images, t, h, w]
+    image_grid_thws: torch.Tensor = None
     mrope_position_delta: Optional[torch.Tensor] = None

     # Qwen2-VL video related
     video_token_id: Optional[int] = None
 ...
@@ -168,7 +170,7 @@ class ImageInputs:
     second_per_grid_ts: Optional[List[torch.Tensor]] = None

     # deepseek vl2 related
-    image_seq_mask: Optional[List[torch.Tensor]] = None
+    images_emb_mask: Optional[List[torch.Tensor]] = None
     image_spatial_crop: Optional[List[torch.Tensor]] = None

     # The id of the single-image placeholder token
 ...
@@ -182,9 +184,6 @@ class ImageInputs:
     slice_end_id: Optional[int] = None
     tgt_sizes: Optional[list] = None

-    # denotes the number of valid image tokens in each image
-    images_emb_mask: Optional[torch.BoolTensor] = None

     @staticmethod
     def from_dict(obj: dict):
         ret = ImageInputs(
 ...
@@ -204,7 +203,7 @@ class ImageInputs:
             "aspect_ratio_ids",
             "aspect_ratio_mask",
             "image_grid_thws",
-            "image_seq_mask",
+            "images_emb_mask",
             "image_spatial_crop",
             "im_token_id",
             "im_start_id",
 ...
@@ -212,20 +211,58 @@ class ImageInputs:
             "slice_start_id",
             "slice_end_id",
             "tgt_sizes",
-            "images_emb_mask",
         ]
         for arg in optional_args:
             if arg in obj:
                 setattr(ret, arg, obj[arg])

+        # validate
+        assert (
+            isinstance(ret.pixel_values, torch.Tensor)
+            or isinstance(ret.pixel_values, np.ndarray)
+            or isinstance(ret.pixel_values, list)
+        )
+
         return ret

-    def merge(self, other):
-        assert self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
-        self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
+    def merge(self, other: ImageInputs):
+        """
+        merge image inputs when requests are being merged
+        """
+        if isinstance(self.pixel_values, list):
+            # in some rare cases, pixel values are list of patches with different shapes
+            # e.g. minicpm
+            self.pixel_values += other.pixel_values
+        else:
+            assert (
+                self.pixel_values.shape[1:] == other.pixel_values.shape[1:]
+            ), f"{self.pixel_values.shape[1:]} vs {other.pixel_values.shape[1:]}"
+            self.pixel_values = np.concatenate([self.pixel_values, other.pixel_values])
+
+        # args would be stacked along first dim
+        # usually these are already tensors
+        stack_args = [
+            # TODO: merge with image_grid_thws, basically the same thing
+            "tgt_sizes",
+            "image_spatial_crop",
+        ]
+        for arg in stack_args:
+            if getattr(self, arg, None) is None:
+                setattr(self, arg, getattr(other, arg, None))
+            elif getattr(other, arg, None) is not None:
+                # self and other both not None
+                setattr(
+                    self,
+                    arg,
+                    torch.cat([getattr(self, arg), getattr(other, arg)], dim=0),
+                )
+
+        if self.image_grid_thws is None:
+            self.image_grid_thws = other.image_grid_thws
+        elif other.image_grid_thws is not None:
+            self.image_grid_thws = torch.concat(
+                [self.image_grid_thws, other.image_grid_thws]
+            )

         # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
         # Please note that if the `input_ids` is later used in the model forward,
 ...
@@ -233,7 +270,7 @@ class ImageInputs:
         # errors in cuda kernels. See also llava.py for example.
         self.image_hashes += other.image_hashes
         self.pad_values = [x % (1 << 30) for x in self.image_hashes]

         # args needed to be merged
         optional_args = [
             "image_sizes",
             "image_offsets",
 ...
@@ -241,13 +278,13 @@ class ImageInputs:
             # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
             "aspect_ratio_ids",
             "aspect_ratio_mask",
-            "image_grid_thws",
-            "image_seq_mask",
-            "image_spatial_crop",
+            "images_emb_mask",
         ]
         for arg in optional_args:
-            if getattr(self, arg, None) is not None:
-                setattr(self, arg, getattr(self, arg) + getattr(other, arg))
+            self_arg = getattr(self, arg, None)
+            if self_arg is not None:
+                setattr(self, arg, self_arg + getattr(other, arg))
+        # other args would be kept intact


 class Req:
 ...
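As the comments above note, image hashes double as fake `input_ids` for radix-cache prefix matching, so `merge()` folds them into a 30-bit range; the real embeddings for those positions are substituted later and the ids are clamped to the vocab before lookup. A small illustration with made-up hashes:

```python
# Small illustration of the pad-value scheme: image hashes are folded into a
# 30-bit range so they can serve as fake input_ids. Hash values are made up.
image_hashes = [0xDEADBEEFCAFE, 0x12345678ABCD]
pad_values = [x % (1 << 30) for x in image_hashes]

# Every pad value now fits in 30 bits; the later clamp to [0, vocab_size - 1]
# keeps the embedding lookup in range before vision embeddings overwrite it.
assert all(0 <= v < (1 << 30) for v in pad_values)
```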
--- a/python/sglang/srt/managers/tokenizer_manager.py
+++ b/python/sglang/srt/managers/tokenizer_manager.py
@@ -179,7 +179,7 @@ class TokenizerManager:
         )

         # We want to parallelize the image pre-processing so we create an executor for it
-        # We creat image_processor for any skip_tokenizer_init to make sure we still encode
+        # We create image_processor for any skip_tokenizer_init to make sure we still encode
         # images even with skip_tokenizer_init=False.
         self.image_processor = get_image_processor(
             self.model_config.hf_config, server_args, _processor
 ...
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -332,7 +332,7 @@ class ForwardBatch:
         return ret

-    def get_merged_image_inputs(self) -> Optional[ImageInputs]:
+    def merge_image_inputs(self) -> Optional[ImageInputs]:
         """
         Merge all image inputs in the batch into a single ImageInputs object.
 ...
@@ -358,6 +358,16 @@ class ForwardBatch:
         return merged

+    def contains_image_inputs(self) -> bool:
+        """ """
+        if self.image_inputs is None:
+            return True
+        return any(
+            image_input.pixel_values is not None and image_input.pixel_values is not []
+            for image_input in self.image_inputs
+            if image_input is not None
+        )
+
     def _compute_mrope_positions(self, model_runner: ModelRunner, batch: ModelWorkerBatch):
 ...
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -273,7 +273,7 @@ class ModelRunner:
         if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
             # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
             logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for deekseek-vl2."
+                "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
             )
             server_args.chunked_prefill_size = -1
             server_args.disable_radix_cache = True
 ...