Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
136c6e04
Unverified
Commit
136c6e04
authored
Jul 08, 2025
by
Xinyuan Tong
Committed by
GitHub
Jul 08, 2025
Browse files
fix: Handles input_embeds in GenerateReqInput when n>1 (#7830)
Signed-off-by:
Xinyuan Tong
<
justinning0323@outlook.com
>
parent
43e20c06
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
75 additions
and
1 deletion
+75
-1
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+8
-1
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_io_struct.py
test/srt/test_io_struct.py
+66
-0
No files found.
python/sglang/srt/managers/io_struct.py
View file @
136c6e04
...
...
@@ -200,6 +200,8 @@ class GenerateReqInput:
self
.
text
=
[
self
.
text
]
if
self
.
input_ids
is
not
None
:
self
.
input_ids
=
[
self
.
input_ids
]
if
self
.
input_embeds
is
not
None
:
self
.
input_embeds
=
[
self
.
input_embeds
]
def
_normalize_single_inputs
(
self
):
"""Normalize inputs for a single example."""
...
...
@@ -324,7 +326,9 @@ class GenerateReqInput:
new_rids
=
[
f
"
{
self
.
rid
}
_
{
i
}
"
for
i
in
range
(
num
)]
self
.
rid
=
new_rids
elif
isinstance
(
self
.
rid
,
list
):
if
len
(
self
.
rid
)
!=
num
:
# Note: the length of rid shall be the same as the batch_size,
# as the rid would be expanded for parallel sampling in tokenizer_manager
if
len
(
self
.
rid
)
!=
self
.
batch_size
:
raise
ValueError
(
"The specified rids length mismatch with the batch_size for batch processing."
)
...
...
@@ -400,6 +404,9 @@ class GenerateReqInput:
return
GenerateReqInput
(
text
=
self
.
text
[
i
]
if
self
.
text
is
not
None
else
None
,
input_ids
=
self
.
input_ids
[
i
]
if
self
.
input_ids
is
not
None
else
None
,
input_embeds
=
(
self
.
input_embeds
[
i
]
if
self
.
input_embeds
is
not
None
else
None
),
image_data
=
self
.
image_data
[
i
],
audio_data
=
self
.
audio_data
[
i
],
sampling_params
=
self
.
sampling_params
[
i
],
...
...
test/srt/run_suite.py
View file @
136c6e04
...
...
@@ -67,6 +67,7 @@ suites = {
TestFile
(
"test_hidden_states.py"
,
55
),
TestFile
(
"test_int8_kernel.py"
,
8
),
TestFile
(
"test_input_embeddings.py"
,
38
),
TestFile
(
"test_io_struct.py"
,
8
),
TestFile
(
"test_jinja_template_utils.py"
,
1
),
TestFile
(
"test_metrics.py"
,
32
),
TestFile
(
"test_mla.py"
,
167
),
...
...
test/srt/test_io_struct.py
View file @
136c6e04
...
...
@@ -159,6 +159,7 @@ class TestGenerateReqInputNormalization(CustomTestCase):
"""Test that when some batch items have images and others None, parallel expansion works correctly."""
req
=
copy
.
deepcopy
(
self
.
base_req
)
req
.
text
=
[
"Prompt 1"
,
"Prompt 2"
,
"Prompt 3"
]
req
.
rid
=
[
"id1"
,
"id2"
,
"id3"
]
req
.
image_data
=
[
[
"image1.jpg"
],
None
,
...
...
@@ -311,6 +312,71 @@ class TestGenerateReqInputNormalization(CustomTestCase):
self
.
assertFalse
(
req
.
is_single
)
self
.
assertEqual
(
req
.
batch_size
,
2
)
def
test_input_embeds_with_parallel_sampling
(
self
):
"""Test input_embeds normalization with parallel sampling (n > 1)."""
# Test single input_embeds with parallel sampling
req
=
GenerateReqInput
(
input_embeds
=
[[
0.1
,
0.2
]],
# single embedding vector
sampling_params
=
{
"n"
:
2
},
)
req
.
normalize_batch_and_arguments
()
# Should be converted from single to batch and then expanded
self
.
assertFalse
(
req
.
is_single
)
self
.
assertEqual
(
len
(
req
.
input_embeds
),
2
)
# Both should be the same input_embeds
self
.
assertEqual
(
req
.
input_embeds
[
0
],
[[
0.1
,
0.2
]])
self
.
assertEqual
(
req
.
input_embeds
[
1
],
[[
0.1
,
0.2
]])
# Test batch input_embeds with parallel sampling
req
=
GenerateReqInput
(
input_embeds
=
[[[
0.1
,
0.2
]],
[[
0.3
,
0.4
]]],
sampling_params
=
{
"n"
:
3
}
)
req
.
normalize_batch_and_arguments
()
# Should be expanded
self
.
assertFalse
(
req
.
is_single
)
self
.
assertEqual
(
len
(
req
.
input_embeds
),
6
)
# Check that the expansion is correct
expected_embeds
=
[[[
0.1
,
0.2
]],
[[
0.3
,
0.4
]]]
*
3
self
.
assertEqual
(
req
.
input_embeds
,
expected_embeds
)
# Test with different n values per sample (should raise error)
req
=
GenerateReqInput
(
input_embeds
=
[[[
0.1
,
0.2
]],
[[
0.3
,
0.4
]]],
sampling_params
=
[{
"n"
:
2
},
{
"n"
:
3
}],
)
with
self
.
assertRaises
(
ValueError
):
req
.
normalize_batch_and_arguments
()
def
test_input_embeds_single_to_batch_conversion
(
self
):
"""Test that single input_embeds are properly converted to batch when using parallel sampling."""
# Test the specific case that was fixed: single input_embeds with n > 1
req
=
GenerateReqInput
(
input_embeds
=
[[
0.1
,
0.2
,
0.3
]],
sampling_params
=
{
"n"
:
2
}
# Single embedding
)
req
.
normalize_batch_and_arguments
()
# Should convert single to batch and then expand
self
.
assertFalse
(
req
.
is_single
)
self
.
assertEqual
(
len
(
req
.
input_embeds
),
2
)
# Both should be the same single embedding
self
.
assertEqual
(
req
.
input_embeds
[
0
],
[[
0.1
,
0.2
,
0.3
]])
self
.
assertEqual
(
req
.
input_embeds
[
1
],
[[
0.1
,
0.2
,
0.3
]])
# Test with higher n value
req
=
GenerateReqInput
(
input_embeds
=
[[
0.1
,
0.2
,
0.3
]],
sampling_params
=
{
"n"
:
5
})
req
.
normalize_batch_and_arguments
()
self
.
assertFalse
(
req
.
is_single
)
self
.
assertEqual
(
len
(
req
.
input_embeds
),
5
)
# All should be the same
for
i
in
range
(
5
):
self
.
assertEqual
(
req
.
input_embeds
[
i
],
[[
0.1
,
0.2
,
0.3
]])
def
test_lora_path_normalization
(
self
):
"""Test normalization of lora_path."""
# Test single lora_path with batch input
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment