Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
136c6e04
Unverified
Commit
136c6e04
authored
Jul 08, 2025
by
Xinyuan Tong
Committed by
GitHub
Jul 08, 2025
Browse files
fix: Handles input_embeds in GenerateReqInput when n>1 (#7830)
Signed-off-by:
Xinyuan Tong
<
justinning0323@outlook.com
>
parent
43e20c06
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
75 additions
and
1 deletion
+75
-1
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+8
-1
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_io_struct.py
test/srt/test_io_struct.py
+66
-0
No files found.
python/sglang/srt/managers/io_struct.py
View file @
136c6e04
...
@@ -200,6 +200,8 @@ class GenerateReqInput:
...
@@ -200,6 +200,8 @@ class GenerateReqInput:
self
.
text
=
[
self
.
text
]
self
.
text
=
[
self
.
text
]
if
self
.
input_ids
is
not
None
:
if
self
.
input_ids
is
not
None
:
self
.
input_ids
=
[
self
.
input_ids
]
self
.
input_ids
=
[
self
.
input_ids
]
if
self
.
input_embeds
is
not
None
:
self
.
input_embeds
=
[
self
.
input_embeds
]
def
_normalize_single_inputs
(
self
):
def
_normalize_single_inputs
(
self
):
"""Normalize inputs for a single example."""
"""Normalize inputs for a single example."""
...
@@ -324,7 +326,9 @@ class GenerateReqInput:
...
@@ -324,7 +326,9 @@ class GenerateReqInput:
new_rids
=
[
f
"
{
self
.
rid
}
_
{
i
}
"
for
i
in
range
(
num
)]
new_rids
=
[
f
"
{
self
.
rid
}
_
{
i
}
"
for
i
in
range
(
num
)]
self
.
rid
=
new_rids
self
.
rid
=
new_rids
elif
isinstance
(
self
.
rid
,
list
):
elif
isinstance
(
self
.
rid
,
list
):
if
len
(
self
.
rid
)
!=
num
:
# Note: the length of rid shall be the same as the batch_size,
# as the rid would be expanded for parallel sampling in tokenizer_manager
if
len
(
self
.
rid
)
!=
self
.
batch_size
:
raise
ValueError
(
raise
ValueError
(
"The specified rids length mismatch with the batch_size for batch processing."
"The specified rids length mismatch with the batch_size for batch processing."
)
)
...
@@ -400,6 +404,9 @@ class GenerateReqInput:
...
@@ -400,6 +404,9 @@ class GenerateReqInput:
return
GenerateReqInput
(
return
GenerateReqInput
(
text
=
self
.
text
[
i
]
if
self
.
text
is
not
None
else
None
,
text
=
self
.
text
[
i
]
if
self
.
text
is
not
None
else
None
,
input_ids
=
self
.
input_ids
[
i
]
if
self
.
input_ids
is
not
None
else
None
,
input_ids
=
self
.
input_ids
[
i
]
if
self
.
input_ids
is
not
None
else
None
,
input_embeds
=
(
self
.
input_embeds
[
i
]
if
self
.
input_embeds
is
not
None
else
None
),
image_data
=
self
.
image_data
[
i
],
image_data
=
self
.
image_data
[
i
],
audio_data
=
self
.
audio_data
[
i
],
audio_data
=
self
.
audio_data
[
i
],
sampling_params
=
self
.
sampling_params
[
i
],
sampling_params
=
self
.
sampling_params
[
i
],
...
...
test/srt/run_suite.py
View file @
136c6e04
...
@@ -67,6 +67,7 @@ suites = {
...
@@ -67,6 +67,7 @@ suites = {
TestFile
(
"test_hidden_states.py"
,
55
),
TestFile
(
"test_hidden_states.py"
,
55
),
TestFile
(
"test_int8_kernel.py"
,
8
),
TestFile
(
"test_int8_kernel.py"
,
8
),
TestFile
(
"test_input_embeddings.py"
,
38
),
TestFile
(
"test_input_embeddings.py"
,
38
),
TestFile
(
"test_io_struct.py"
,
8
),
TestFile
(
"test_jinja_template_utils.py"
,
1
),
TestFile
(
"test_jinja_template_utils.py"
,
1
),
TestFile
(
"test_metrics.py"
,
32
),
TestFile
(
"test_metrics.py"
,
32
),
TestFile
(
"test_mla.py"
,
167
),
TestFile
(
"test_mla.py"
,
167
),
...
...
test/srt/test_io_struct.py
View file @
136c6e04
...
@@ -159,6 +159,7 @@ class TestGenerateReqInputNormalization(CustomTestCase):
...
@@ -159,6 +159,7 @@ class TestGenerateReqInputNormalization(CustomTestCase):
"""Test that when some batch items have images and others None, parallel expansion works correctly."""
"""Test that when some batch items have images and others None, parallel expansion works correctly."""
req
=
copy
.
deepcopy
(
self
.
base_req
)
req
=
copy
.
deepcopy
(
self
.
base_req
)
req
.
text
=
[
"Prompt 1"
,
"Prompt 2"
,
"Prompt 3"
]
req
.
text
=
[
"Prompt 1"
,
"Prompt 2"
,
"Prompt 3"
]
req
.
rid
=
[
"id1"
,
"id2"
,
"id3"
]
req
.
image_data
=
[
req
.
image_data
=
[
[
"image1.jpg"
],
[
"image1.jpg"
],
None
,
None
,
...
@@ -311,6 +312,71 @@ class TestGenerateReqInputNormalization(CustomTestCase):
...
@@ -311,6 +312,71 @@ class TestGenerateReqInputNormalization(CustomTestCase):
self
.
assertFalse
(
req
.
is_single
)
self
.
assertFalse
(
req
.
is_single
)
self
.
assertEqual
(
req
.
batch_size
,
2
)
self
.
assertEqual
(
req
.
batch_size
,
2
)
def
test_input_embeds_with_parallel_sampling
(
self
):
"""Test input_embeds normalization with parallel sampling (n > 1)."""
# Test single input_embeds with parallel sampling
req
=
GenerateReqInput
(
input_embeds
=
[[
0.1
,
0.2
]],
# single embedding vector
sampling_params
=
{
"n"
:
2
},
)
req
.
normalize_batch_and_arguments
()
# Should be converted from single to batch and then expanded
self
.
assertFalse
(
req
.
is_single
)
self
.
assertEqual
(
len
(
req
.
input_embeds
),
2
)
# Both should be the same input_embeds
self
.
assertEqual
(
req
.
input_embeds
[
0
],
[[
0.1
,
0.2
]])
self
.
assertEqual
(
req
.
input_embeds
[
1
],
[[
0.1
,
0.2
]])
# Test batch input_embeds with parallel sampling
req
=
GenerateReqInput
(
input_embeds
=
[[[
0.1
,
0.2
]],
[[
0.3
,
0.4
]]],
sampling_params
=
{
"n"
:
3
}
)
req
.
normalize_batch_and_arguments
()
# Should be expanded
self
.
assertFalse
(
req
.
is_single
)
self
.
assertEqual
(
len
(
req
.
input_embeds
),
6
)
# Check that the expansion is correct
expected_embeds
=
[[[
0.1
,
0.2
]],
[[
0.3
,
0.4
]]]
*
3
self
.
assertEqual
(
req
.
input_embeds
,
expected_embeds
)
# Test with different n values per sample (should raise error)
req
=
GenerateReqInput
(
input_embeds
=
[[[
0.1
,
0.2
]],
[[
0.3
,
0.4
]]],
sampling_params
=
[{
"n"
:
2
},
{
"n"
:
3
}],
)
with
self
.
assertRaises
(
ValueError
):
req
.
normalize_batch_and_arguments
()
def
test_input_embeds_single_to_batch_conversion
(
self
):
"""Test that single input_embeds are properly converted to batch when using parallel sampling."""
# Test the specific case that was fixed: single input_embeds with n > 1
req
=
GenerateReqInput
(
input_embeds
=
[[
0.1
,
0.2
,
0.3
]],
sampling_params
=
{
"n"
:
2
}
# Single embedding
)
req
.
normalize_batch_and_arguments
()
# Should convert single to batch and then expand
self
.
assertFalse
(
req
.
is_single
)
self
.
assertEqual
(
len
(
req
.
input_embeds
),
2
)
# Both should be the same single embedding
self
.
assertEqual
(
req
.
input_embeds
[
0
],
[[
0.1
,
0.2
,
0.3
]])
self
.
assertEqual
(
req
.
input_embeds
[
1
],
[[
0.1
,
0.2
,
0.3
]])
# Test with higher n value
req
=
GenerateReqInput
(
input_embeds
=
[[
0.1
,
0.2
,
0.3
]],
sampling_params
=
{
"n"
:
5
})
req
.
normalize_batch_and_arguments
()
self
.
assertFalse
(
req
.
is_single
)
self
.
assertEqual
(
len
(
req
.
input_embeds
),
5
)
# All should be the same
for
i
in
range
(
5
):
self
.
assertEqual
(
req
.
input_embeds
[
i
],
[[
0.1
,
0.2
,
0.3
]])
def
test_lora_path_normalization
(
self
):
def
test_lora_path_normalization
(
self
):
"""Test normalization of lora_path."""
"""Test normalization of lora_path."""
# Test single lora_path with batch input
# Test single lora_path with batch input
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment