Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
53475674
Unverified
Commit
53475674
authored
Jul 26, 2025
by
Mick
Committed by
GitHub
Jul 26, 2025
Browse files
chore: improvements on mm_utils (#7737)
parent
ce32bc2b
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
33 deletions
+29
-33
python/sglang/srt/managers/mm_utils.py
python/sglang/srt/managers/mm_utils.py
+29
-33
No files found.
python/sglang/srt/managers/mm_utils.py
View file @
53475674
...
...
@@ -85,8 +85,8 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
"No data_token_pairs provided, RadixAttention might be influenced."
)
return
input_ids
start_token_ids
=
[
s
for
s
,
_e
in
data_token_pairs
]
end_tokens_ids
=
[
e
for
_s
,
e
in
data_token_pairs
]
start_token_ids
=
{
s
for
s
,
_e
in
data_token_pairs
}
end_tokens_ids
=
{
e
for
_s
,
e
in
data_token_pairs
}
padded_ids
=
[]
last_idx
=
0
...
...
@@ -135,7 +135,7 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
if
not
input_ids
or
not
mm_inputs
.
mm_items
:
return
input_ids
input_ids_tensor
=
torch
.
tensor
(
input_ids
)
input_ids_tensor
=
torch
.
as_
tensor
(
input_ids
)
# Create mapping of token_ids to pad_values for each modality
token_to_pad_mapping
=
{}
...
...
@@ -211,7 +211,7 @@ def get_embedding_chunk(
end_index
+=
extend_end_index
-
start
+
1
elif
extend_end_index
>
end
:
end_index
+=
end
-
start
+
1
# some models embedding is 3-dim, reshape it to 2-dim
# some models
'
embedding is 3-dim, reshape it to 2-dim
embedding
=
embedding
.
reshape
(
-
1
,
embedding
.
shape
[
-
1
])
embedding_chunk
=
embedding
[
start_index
:
end_index
]
return
embedding_chunk
,
start_index
,
end_index
...
...
@@ -428,7 +428,7 @@ def embed_mm_inputs(
modality_id
=
modality
.
name
.
lower
()
embedder
=
getattr
(
multimodal_model
,
f
"get_
{
modality_id
}
_feature"
,
None
)
if
len
(
items
)
!=
0
and
embedder
is
not
None
:
placeholder_tensor
=
torch
.
tensor
(
placeholder_tensor
=
torch
.
as_
tensor
(
[
item
.
pad_value
for
item
in
items
],
device
=
input_ids
.
device
,
)
...
...
@@ -473,11 +473,9 @@ def embed_mm_inputs(
for
embedding
,
mask
in
zip
(
embeddings
,
masks
):
if
embedding
is
None
or
mask
is
None
:
continue
mask
=
mask
.
expand_as
(
inputs_embeds
).
to
(
inputs_embeds
.
device
)
inputs_embeds
=
inputs_embeds
.
masked_scatter
(
mask
,
embedding
.
to
(
inputs_embeds
.
device
,
inputs_embeds
.
dtype
),
)
# in-place update
indices
=
torch
.
where
(
mask
.
squeeze
(
dim
=-
1
))[
0
]
inputs_embeds
[
indices
]
=
embedding
.
to
(
inputs_embeds
.
device
,
inputs_embeds
.
dtype
)
return
inputs_embeds
...
...
@@ -561,34 +559,36 @@ def get_multimodal_data_bounds(
[bounds_count, 2]
"""
# All the multimodal data in the batch should share the same special bound token ids.
start_tokens
=
[
s
for
s
,
_e
in
token_pairs
]
end_tokens
=
[
e
for
_s
,
e
in
token_pairs
]
start_tokens
=
{
s
for
s
,
_e
in
token_pairs
}
end_tokens
=
{
e
for
_s
,
e
in
token_pairs
}
assert
all
(
isinstance
(
t
,
int
)
for
t
in
start_tokens
)
assert
all
(
isinstance
(
t
,
int
)
for
t
in
end_tokens
)
start_cond
=
torch
.
isin
(
input_ids
,
torch
.
tensor
(
start_tokens
,
device
=
input_ids
.
device
)
input_ids
,
torch
.
as_tensor
(
start_tokens
,
device
=
input_ids
.
device
)
)
end_cond
=
torch
.
isin
(
input_ids
,
torch
.
as_tensor
(
end_tokens
,
device
=
input_ids
.
device
)
)
end_cond
=
torch
.
isin
(
input_ids
,
torch
.
tensor
(
end_tokens
,
device
=
input_ids
.
device
))
(
data_start_tokens
,)
=
torch
.
where
(
start_cond
)
(
data_end_tokens
,)
=
torch
.
where
(
end_cond
)
data_start_tokens_cpu
=
data_start_tokens
.
cpu
().
tolist
()
data_end_tokens_cpu
=
data_end_tokens
.
cpu
().
tolist
()
# the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
if
len
(
data_start_tokens
)
!=
len
(
data_end_tokens
):
if
len
(
data_start_tokens
_cpu
)
!=
len
(
data_end_tokens
_cpu
):
if
(
len
(
data_start_tokens
)
+
1
==
len
(
data_end_tokens
)
and
input_ids
[
0
]
in
pad_values
and
data_end_tokens
[
0
]
<
data_start_tokens
[
0
]
len
(
data_start_tokens_cpu
)
+
1
==
len
(
data_end_tokens_cpu
)
and
input_ids
[
0
].
item
()
in
pad_values
and
data_end_tokens_cpu
and
data_start_tokens_cpu
and
data_end_tokens_cpu
[
0
]
<
data_start_tokens_cpu
[
0
]
):
data_start_tokens
=
torch
.
cat
(
[
torch
.
tensor
([
0
],
device
=
data_start_tokens
.
device
),
data_start_tokens
,
]
)
valid_mm_data_nums
=
min
(
len
(
data_start_tokens
),
len
(
data_end_tokens
))
data_start_tokens_cpu
.
insert
(
0
,
0
)
valid_mm_data_nums
=
min
(
len
(
data_start_tokens_cpu
),
len
(
data_end_tokens_cpu
))
if
valid_mm_data_nums
==
0
:
return
torch
.
zeros
((
0
,
2
),
device
=
input_ids
.
device
)
...
...
@@ -596,8 +596,8 @@ def get_multimodal_data_bounds(
# Filter out pairs where start_token >= end_token
valid_pairs
=
[]
for
i
in
range
(
valid_mm_data_nums
):
start_token
=
data_start_tokens
[
i
]
end_token
=
data_end_tokens
[
i
]
start_token
=
data_start_tokens
_cpu
[
i
]
end_token
=
data_end_tokens
_cpu
[
i
]
if
start_token
<
end_token
:
valid_pairs
.
append
((
start_token
+
1
,
end_token
-
1
))
...
...
@@ -605,7 +605,7 @@ def get_multimodal_data_bounds(
return
torch
.
zeros
((
0
,
2
),
device
=
input_ids
.
device
)
# Convert valid pairs to tensor
valid_pairs_tensor
=
torch
.
tensor
(
valid_pairs
,
device
=
input_ids
.
device
)
valid_pairs_tensor
=
torch
.
as_
tensor
(
valid_pairs
,
device
=
input_ids
.
device
)
return
valid_pairs_tensor
...
...
@@ -634,11 +634,7 @@ def tensor_hash(tensor_list) -> int:
tensor
=
tensor
.
float
()
assert
isinstance
(
tensor
,
torch
.
Tensor
)
if
tensor
.
is_cuda
:
# TODO: improve this
tensor_cpu
=
tensor
.
cpu
()
else
:
tensor_cpu
=
tensor
tensor_cpu
=
tensor
.
cpu
()
mv
=
memoryview
(
tensor_cpu
.
numpy
())
return
data_hash
(
mv
.
tobytes
())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment