Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e0e09fce
Unverified
Commit
e0e09fce
authored
Dec 29, 2024
by
Ying Sheng
Committed by
GitHub
Dec 29, 2024
Browse files
[Session] Update session control interface (#2635)
parent
9c05c689
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
530 additions
and
90 deletions
+530
-90
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+15
-8
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+23
-10
python/sglang/srt/managers/session_controller.py
python/sglang/srt/managers/session_controller.py
+102
-27
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+15
-10
python/sglang/srt/server.py
python/sglang/srt/server.py
+4
-0
test/srt/test_session_control.py
test/srt/test_session_control.py
+371
-35
No files found.
python/sglang/srt/managers/io_struct.py
View file @
e0e09fce
...
...
@@ -27,6 +27,14 @@ from sglang.srt.managers.schedule_batch import BaseFinishReason
from
sglang.srt.sampling.sampling_params
import
SamplingParams
@dataclass
class SessionParams:
    """Per-request session-control parameters; every field defaults to None."""

    # Id of the session the request targets; the scheduler looks it up in its
    # table of open sessions.
    id: Optional[str] = None
    # Id of an earlier request in the same session to build on; the session
    # looks it up among its recorded request nodes.
    rid: Optional[str] = None
    # When set and non-zero, the earlier request's token ids are sliced to
    # [:offset] before the new input ids are appended.
    offset: Optional[int] = None
    # When truthy, the referenced request is aborted and its descendants are
    # cleared before the new request is created.
    replace: Optional[bool] = None
@
dataclass
class
GenerateReqInput
:
# The input prompt. It can be a single prompt or a batch of prompts.
...
...
@@ -58,10 +66,8 @@ class GenerateReqInput:
# LoRA related
lora_path
:
Optional
[
Union
[
List
[
Optional
[
str
]],
Optional
[
str
]]]
=
None
# Session id info for continual prompting
session
:
Optional
[
Union
[
List
[
Tuple
[
str
,
Optional
[
str
]]],
Tuple
[
str
,
Optional
[
str
]]]
]
=
None
# Session info for continual prompting
session_params
:
Optional
[
Union
[
List
[
Dict
],
Dict
]]
=
None
def
normalize_batch_and_arguments
(
self
):
if
(
...
...
@@ -223,9 +229,8 @@ class TokenizedGenerateReqInput:
# The input embeds
input_embeds
:
Optional
[
Union
[
List
[
List
[
List
[
float
]]],
List
[
List
[
float
]]]]
=
None
# Session id info for continual prompting
session_id
:
Optional
[
str
]
=
None
session_rid
:
Optional
[
str
]
=
None
# Session info for continual prompting
session_params
:
Optional
[
SessionParams
]
=
None
@
dataclass
...
...
@@ -468,6 +473,7 @@ class ProfileReq(Enum):
@dataclass
class OpenSessionReqInput:
    """Request to open a new session."""

    # Passed unchanged to the Session constructor by the scheduler.
    # NOTE(review): presumably a capacity limit in characters -- confirm.
    capacity_of_str_len: int
    # Caller-chosen session id; when left as None the tokenizer manager
    # fills in a random uuid4 hex before forwarding the request.
    session_id: Optional[str] = None
@
dataclass
...
...
@@ -477,4 +483,5 @@ class CloseSessionReqInput:
@
dataclass
class
OpenSessionReqOutput
:
session_id
:
str
session_id
:
Optional
[
str
]
success
:
bool
python/sglang/srt/managers/scheduler.py
View file @
e0e09fce
...
...
@@ -22,7 +22,7 @@ import warnings
from
collections
import
deque
from
concurrent
import
futures
from
types
import
SimpleNamespace
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
psutil
import
setproctitle
...
...
@@ -498,8 +498,10 @@ class Scheduler:
else
:
self
.
stop_profile
()
elif
isinstance
(
recv_req
,
OpenSessionReqInput
):
session_id
=
self
.
open_session
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
OpenSessionReqOutput
(
session_id
))
session_id
,
success
=
self
.
open_session
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
OpenSessionReqOutput
(
session_id
=
session_id
,
success
=
success
)
)
elif
isinstance
(
recv_req
,
CloseSessionReqInput
):
self
.
close_session
(
recv_req
)
else
:
...
...
@@ -510,7 +512,11 @@ class Scheduler:
recv_req
:
TokenizedGenerateReqInput
,
):
# Create a new request
if
recv_req
.
session_id
is
None
or
recv_req
.
session_id
not
in
self
.
sessions
:
if
(
recv_req
.
session_params
is
None
or
recv_req
.
session_params
.
id
is
None
or
recv_req
.
session_params
.
id
not
in
self
.
sessions
):
if
recv_req
.
input_embeds
is
not
None
:
# Generate fake input_ids based on the length of input_embeds
...
...
@@ -532,15 +538,18 @@ class Scheduler:
)
req
.
tokenizer
=
self
.
tokenizer
if
recv_req
.
session_id
is
not
None
:
if
(
recv_req
.
session_params
is
not
None
and
recv_req
.
session_params
.
id
is
not
None
):
req
.
finished_reason
=
FINISH_ABORT
(
f
"Invalid request: session id
{
recv_req
.
session_id
}
does not exist"
f
"Invalid request: session id
{
recv_req
.
session_
params
.
id
}
does not exist"
)
self
.
waiting_queue
.
append
(
req
)
return
else
:
# Create a new request from a previ
s
ou session
session
=
self
.
sessions
[
recv_req
.
session_id
]
# Create a new request from a previou
s
session
session
=
self
.
sessions
[
recv_req
.
session_
params
.
id
]
req
=
session
.
create_req
(
recv_req
,
self
.
tokenizer
)
if
isinstance
(
req
.
finished_reason
,
FINISH_ABORT
):
self
.
waiting_queue
.
append
(
req
)
...
...
@@ -1500,16 +1509,20 @@ class Scheduler:
)
logger
.
info
(
"Profiler is done"
)
def
open_session
(
self
,
recv_req
:
OpenSessionReqInput
)
->
str
:
def
open_session
(
self
,
recv_req
:
OpenSessionReqInput
)
->
Tuple
[
Optional
[
str
],
bool
]
:
# handle error
session_id
=
recv_req
.
session_id
if
session_id
in
self
.
sessions
:
logger
.
warning
(
f
"session id
{
session_id
}
already exist, cannot open."
)
return
session_id
,
False
elif
session_id
is
None
:
logger
.
warning
(
f
"session id is None, cannot open."
)
return
session_id
,
False
else
:
self
.
sessions
[
session_id
]
=
Session
(
recv_req
.
capacity_of_str_len
,
session_id
)
return
session_id
return
session_id
,
True
def
close_session
(
self
,
recv_req
:
CloseSessionReqInput
):
# handle error
...
...
python/sglang/srt/managers/session_controller.py
View file @
e0e09fce
...
...
@@ -10,41 +10,116 @@
# limitations under the License.
# ==============================================================================
import
logging
import
uuid
from
typing
import
Dict
,
Optional
from
sglang.srt.managers.io_struct
import
TokenizedGenerateReqInput
from
sglang.srt.managers.schedule_batch
import
FINISH_ABORT
,
List
,
Req
from
sglang.srt.managers.schedule_batch
import
Req
class SessionReqNode:
    """Node of a session's request tree.

    A node wraps one request (``req``) and links to the node it continues
    from (``parent``) and to the nodes that continue from it (``childs``).
    From ``req`` this class reads ``rid`` and ``finished_reason`` and sets
    ``to_abort``.
    """

    def __init__(self, req, parent=None, childs=None):
        """Create a node for ``req`` and register it as a child of ``parent``.

        Args:
            req: The wrapped request object.
            parent: Node this request continues from, or None for a root.
            childs: Optional initial list of child nodes; used as-is (not
                copied) when non-empty.
        """
        self.req = req
        self.parent = parent
        if parent is not None:
            parent.childs.append(self)
        self.childs = childs if childs else []

    def clear_childs(self, req_dict):
        """Clear every descendant of this node, keeping the node itself.

        Args:
            req_dict: Mapping from request id to node; descendants are
                removed from it.
        """
        for req_node in self.childs:
            req_node.clear(req_dict)
        self.childs = []

    def clear(self, req_dict):
        """Abort this subtree and remove all of its nodes from ``req_dict``.

        Children are cleared first, then this node's unfinished request is
        flagged for abort and its rid is dropped from ``req_dict``. Dropping
        a rid that is already gone is a no-op, so a node that is reached both
        through its parent and directly can be cleared twice safely.
        """
        for req_node in self.childs:
            req_node.clear(req_dict)

        self.abort()
        req_dict.pop(self.req.rid, None)

    def abort(self):
        """Flag the wrapped request for abort unless it has already finished."""
        if self.req.finished_reason is None:
            self.req.to_abort = True

    def __str__(self):
        """Render the subtree rooted here, one root-to-leaf path per line."""
        return self._str_helper(self.req.rid)

    def _str_helper(self, prefix=""):
        """Return the text for this subtree, given the text already built for the path.

        The first child extends the current line with `` -- <rid>``; each
        further child starts a new line indented to the width of the parent
        prefix and introduced by `` \\- <rid>``.
        """
        if not self.childs:
            return prefix + "\n"

        origin_prefix = prefix
        first_child = self.childs[0]
        ret = first_child._str_helper(prefix + " -- " + first_child.req.rid)
        for child in self.childs[1:]:
            # Escaped backslash: same runtime text as the former " \- " literal,
            # without the invalid-escape-sequence warning.
            branch_prefix = " " * len(origin_prefix) + " \\- " + child.req.rid
            ret += child._str_helper(branch_prefix)
        return ret
class
Session
:
def
__init__
(
self
,
capacity_of_str_len
:
int
,
session_id
:
str
=
None
):
def
__init__
(
self
,
capacity_of_str_len
:
int
,
session_id
:
Optional
[
str
]
=
None
):
self
.
session_id
=
session_id
if
session_id
is
not
None
else
uuid
.
uuid4
().
hex
self
.
capacity_of_str_len
=
capacity_of_str_len
self
.
req
s
:
List
[
Req
]
=
[]
self
.
req
_nodes
:
Dict
[
str
,
SessionReqNode
]
=
{}
def
create_req
(
self
,
req
:
TokenizedGenerateReqInput
,
tokenizer
):
if
req
.
session_rid
is
not
None
:
while
len
(
self
.
reqs
)
>
0
:
if
self
.
reqs
[
-
1
].
rid
==
req
.
session_rid
:
break
self
.
reqs
=
self
.
reqs
[:
-
1
]
assert
req
.
session_params
is
not
None
session_params
=
req
.
session_params
last_req_node
=
None
last_req
=
None
abort
=
False
if
session_params
.
replace
:
if
session_params
.
rid
is
None
:
for
_
,
req_node
in
self
.
req_nodes
.
items
():
req_node
.
clear
(
self
.
req_nodes
)
else
:
if
session_params
.
rid
not
in
self
.
req_nodes
:
abort
=
True
else
:
last_req_node
=
self
.
req_nodes
[
session_params
.
rid
]
last_req_node
.
abort
()
last_req
=
last_req_node
.
req
last_req_node
.
clear_childs
(
self
.
req_nodes
)
else
:
self
.
reqs
=
[]
if
len
(
self
.
reqs
)
>
0
:
if
session_params
.
rid
is
not
None
:
if
session_params
.
rid
not
in
self
.
req_nodes
:
abort
=
True
else
:
last_req_node
=
self
.
req_nodes
[
session_params
.
rid
]
last_req
=
last_req_node
.
req
if
not
last_req
.
finished
():
logging
.
warning
(
"The request in a session is appending to a request that hasn't finished."
)
abort
=
True
if
last_req
is
not
None
:
# trim bos token if it is an append
if
req
.
input_ids
[
0
]
==
tokenizer
.
bos_token_id
:
req
.
input_ids
=
req
.
input_ids
[
1
:]
input_ids
=
(
self
.
reqs
[
-
1
].
origin_input_ids
+
self
.
reqs
[
-
1
].
output_ids
[
:
self
.
reqs
[
-
1
].
sampling_params
.
max_new_tokens
]
+
req
.
input_ids
last_req
.
origin_input_ids
+
last_req
.
output_ids
[:
last_req
.
sampling_params
.
max_new_tokens
]
)
if
session_params
.
offset
and
session_params
.
offset
!=
0
:
input_ids
=
input_ids
[:
session_params
.
offset
]
+
req
.
input_ids
else
:
input_ids
+=
req
.
input_ids
input_ids_unpadded
=
(
self
.
reqs
[
-
1
].
origin_input_ids_unpadded
+
self
.
reqs
[
-
1
].
output_ids
[
:
self
.
reqs
[
-
1
].
sampling_params
.
max_new_tokens
]
+
req
.
input_ids
last_req
.
origin_input_ids_unpadded
+
last_req
.
output_ids
[:
last_req
.
sampling_params
.
max_new_tokens
]
)
if
session_params
.
offset
and
session_params
.
offset
!=
0
:
input_ids_unpadded
=
(
input_ids_unpadded
[:
session_params
.
offset
]
+
req
.
input_ids
)
else
:
input_ids_unpadded
+=
req
.
input_ids
else
:
input_ids
=
req
.
input_ids
input_ids_unpadded
=
req
.
input_ids
...
...
@@ -57,13 +132,13 @@ class Session:
lora_path
=
req
.
lora_path
,
session_id
=
self
.
session_id
,
)
if
l
en
(
self
.
reqs
)
>
0
:
new_req
.
image_inputs
=
self
.
reqs
[
-
1
]
.
image_inputs
if
l
ast_req
is
not
None
:
new_req
.
image_inputs
=
last_req
.
image_inputs
new_req
.
tokenizer
=
tokenizer
if
req
.
session_rid
is
not
None
and
len
(
self
.
reqs
)
==
0
:
new_req
.
finished_reason
=
FINISH_ABORT
(
f
"Invalid request: requested session rid
{
req
.
session_rid
}
does not exist in the session history"
)
if
abort
:
new_req
.
to_abort
=
True
else
:
self
.
reqs
.
append
(
new_req
)
new_req_node
=
SessionReqNode
(
new_req
,
last_req_node
)
self
.
req_nodes
[
req
.
rid
]
=
new_req_node
return
new_req
python/sglang/srt/managers/tokenizer_manager.py
View file @
e0e09fce
...
...
@@ -53,6 +53,7 @@ from sglang.srt.managers.io_struct import (
OpenSessionReqInput
,
OpenSessionReqOutput
,
ProfileReq
,
SessionParams
,
TokenizedEmbeddingReqInput
,
TokenizedGenerateReqInput
,
UpdateWeightFromDiskReqInput
,
...
...
@@ -264,8 +265,9 @@ class TokenizerManager:
return_logprob
=
obj
.
return_logprob
logprob_start_len
=
obj
.
logprob_start_len
top_logprobs_num
=
obj
.
top_logprobs_num
session_id
=
obj
.
session
[
0
]
if
obj
.
session
else
None
session_rid
=
obj
.
session
[
1
]
if
obj
.
session
else
None
session_params
=
(
SessionParams
(
**
obj
.
session_params
)
if
obj
.
session_params
else
None
)
if
obj
.
input_ids
is
not
None
and
len
(
input_ids
)
>=
self
.
context_len
:
raise
ValueError
(
...
...
@@ -292,8 +294,7 @@ class TokenizerManager:
obj
.
stream
,
lora_path
=
obj
.
lora_path
,
input_embeds
=
input_embeds
,
session_id
=
session_id
,
session_rid
=
session_rid
,
session_params
=
session_params
,
)
elif
isinstance
(
obj
,
EmbeddingReqInput
):
tokenized_obj
=
TokenizedEmbeddingReqInput
(
...
...
@@ -552,12 +553,16 @@ class TokenizerManager:
):
self
.
auto_create_handle_loop
()
session_id
=
uuid
.
uuid4
().
hex
obj
.
session_id
=
session_id
if
obj
.
session_id
is
None
:
obj
.
session_id
=
uuid
.
uuid4
().
hex
elif
obj
.
session_id
in
self
.
session_futures
:
return
None
self
.
send_to_scheduler
.
send_pyobj
(
obj
)
self
.
session_futures
[
session_id
]
=
asyncio
.
Future
()
session_id
=
await
self
.
session_futures
[
session_id
]
del
self
.
session_futures
[
session_id
]
self
.
session_futures
[
obj
.
session_id
]
=
asyncio
.
Future
()
session_id
=
await
self
.
session_futures
[
obj
.
session_id
]
del
self
.
session_futures
[
obj
.
session_id
]
return
session_id
async
def
close_session
(
...
...
@@ -709,7 +714,7 @@ class TokenizerManager:
)
elif
isinstance
(
recv_obj
,
OpenSessionReqOutput
):
self
.
session_futures
[
recv_obj
.
session_id
].
set_result
(
recv_obj
.
session_id
recv_obj
.
session_id
if
recv_obj
.
success
else
None
)
elif
isinstance
(
recv_obj
,
UpdateWeightFromDiskReqOutput
):
if
self
.
server_args
.
dp_size
==
1
:
...
...
python/sglang/srt/server.py
View file @
e0e09fce
...
...
@@ -259,6 +259,10 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id."""
try
:
session_id
=
await
tokenizer_manager
.
open_session
(
obj
,
request
)
if
session_id
is
None
:
raise
Exception
(
"Failed to open the session. Check if a session with the same id is still open."
)
return
session_id
except
Exception
as
e
:
return
_create_error_response
(
e
)
...
...
test/srt/test_session_control.py
View file @
e0e09fce
"""
Usage:
python3 -m unittest test_session_control.TestSessionControl.test_session_control
python3 -m unittest test_session_control.TestSessionControl.test_session_control_with_branching
python3 -m unittest test_session_control.TestSessionControl.test_session_control_backtrack_with_abort
python3 -m unittest test_session_control.TestSessionControlVision.test_session_control
"""
import
asyncio
import
json
import
unittest
import
aiohttp
import
requests
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
...
...
@@ -18,6 +23,10 @@ from sglang.test.test_utils import (
)
def remove_prefix(text: str, prefix: str) -> str:
    """Return ``text`` with a leading ``prefix`` stripped, or unchanged if absent."""
    if not text.startswith(prefix):
        return text
    cut = len(prefix)
    return text[cut:]
class
TestSessionControl
(
unittest
.
TestCase
):
@
classmethod
def
setUpClass
(
cls
):
...
...
@@ -31,15 +40,18 @@ class TestSessionControl(unittest.TestCase):
def
tearDownClass
(
cls
):
kill_process_tree
(
cls
.
process
.
pid
)
def
test_session_control
(
self
):
def
test_session_control
(
self
,
gen_len
=
12
):
chunks
=
[
"Let me tell you something about France."
,
"The capital of France is"
,
"The population of the city is"
,
"A brief history about that city is"
,
"To plan a travel, the budget is"
,
]
tokenizer
=
get_tokenizer
(
self
.
model
)
chunks_ids
=
[
tokenizer
.
encode
(
x
)
for
x
in
chunks
]
for
i
in
range
(
1
,
len
(
chunks_ids
)):
if
chunks_ids
[
i
][
0
]
==
tokenizer
.
bos_token_id
:
chunks_ids
[
i
]
=
chunks_ids
[
i
][
1
:]
# 1. using session control
session_id
=
requests
.
post
(
...
...
@@ -48,6 +60,13 @@ class TestSessionControl(unittest.TestCase):
).
json
()
rid
=
None
# re-opening a session id that is already open must fail with an error response
response
=
requests
.
post
(
self
.
base_url
+
"/open_session"
,
json
=
{
"capacity_of_str_len"
:
1000
,
"session_id"
:
session_id
},
).
json
()
assert
isinstance
(
response
,
dict
)
and
"error"
in
response
first_rid
=
None
outputs_from_session
=
[]
for
i
,
chunk_ids
in
enumerate
(
chunks_ids
):
...
...
@@ -55,11 +74,16 @@ class TestSessionControl(unittest.TestCase):
self
.
base_url
+
"/generate"
,
json
=
{
"input_ids"
:
chunk_ids
,
"session"
:
[
session_id
,
rid
],
"session_params"
:
{
"id"
:
session_id
,
"rid"
:
rid
,
"offset"
:
-
1
,
"replace"
:
True
,
},
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
(
16
if
i
>
0
else
0
gen_len
if
i
>
0
else
1
),
# prefill only for the first chunk
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
...
...
@@ -77,10 +101,15 @@ class TestSessionControl(unittest.TestCase):
self
.
base_url
+
"/generate"
,
json
=
{
"input_ids"
:
chunks_ids
[
-
1
],
"session"
:
[
session_id
,
first_rid
],
"session_params"
:
{
"id"
:
session_id
,
"rid"
:
first_rid
,
"offset"
:
-
1
,
"replace"
:
True
,
},
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
"max_new_tokens"
:
gen_len
,
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
...
...
@@ -93,10 +122,15 @@ class TestSessionControl(unittest.TestCase):
self
.
base_url
+
"/generate"
,
json
=
{
"input_ids"
:
chunks_ids
[
-
1
],
"session"
:
[
session_id
,
rid
],
"session_params"
:
{
"id"
:
session_id
,
"rid"
:
rid
,
"offset"
:
-
1
,
"replace"
:
True
,
},
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
"max_new_tokens"
:
gen_len
,
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
...
...
@@ -115,10 +149,15 @@ class TestSessionControl(unittest.TestCase):
self
.
base_url
+
"/generate"
,
json
=
{
"input_ids"
:
chunks_ids
[
-
1
],
"session"
:
[
session_id
,
first_rid
],
"session_params"
:
{
"id"
:
session_id
,
"rid"
:
first_rid
,
"offset"
:
-
1
,
"replace"
:
True
,
},
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
"max_new_tokens"
:
gen_len
,
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
...
...
@@ -127,6 +166,8 @@ class TestSessionControl(unittest.TestCase):
assert
response
[
"meta_info"
][
"finish_reason"
][
"type"
]
==
"abort"
# 2. not use session control
requests
.
post
(
self
.
base_url
+
"/flush_cache"
)
input_ids_first_req
=
None
input_ids
=
[]
outputs_normal
=
[]
...
...
@@ -139,7 +180,7 @@ class TestSessionControl(unittest.TestCase):
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
(
16
if
i
>
0
else
0
gen_len
if
i
>
0
else
1
),
# prefill only for the first chunk
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
...
...
@@ -150,7 +191,7 @@ class TestSessionControl(unittest.TestCase):
output_ids
=
tokenizer
.
encode
(
response
[
"text"
])
if
output_ids
[
0
]
==
tokenizer
.
bos_token_id
:
output_ids
=
output_ids
[
1
:]
input_ids
+=
output_ids
input_ids
+=
output_ids
[:
-
1
]
outputs_normal
.
append
(
response
[
"text"
])
if
i
==
0
:
input_ids_first_req
=
input_ids
.
copy
()
...
...
@@ -162,7 +203,7 @@ class TestSessionControl(unittest.TestCase):
"input_ids"
:
input_ids_first_req
,
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
"max_new_tokens"
:
gen_len
,
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
...
...
@@ -176,6 +217,272 @@ class TestSessionControl(unittest.TestCase):
print
(
outputs_normal
)
assert
outputs_from_session
==
outputs_normal
async def async_generate(self, payload):
    """Stream a ``/generate`` request and yield ``(text, rid, finish_reason)`` tuples.

    Each non-empty line of the response body has an optional ``"data: "``
    prefix removed. The ``[DONE]`` marker yields ``("", None, "")``; every
    other line is parsed as JSON and yields its ``text``, its
    ``meta_info.id`` and the finish-reason type (``""`` while the finish
    reason is still empty).
    """
    endpoint = self.base_url + "/generate"
    async with aiohttp.ClientSession() as http:
        async with http.post(url=endpoint, json=payload) as resp:
            assert resp.status == 200
            async for raw_line in resp.content:
                stripped = raw_line.strip()
                if not stripped:
                    continue
                event = remove_prefix(stripped.decode("utf-8"), "data: ")
                if event == "[DONE]":
                    yield "", None, ""
                    continue
                parsed = json.loads(event)
                reason = parsed["meta_info"]["finish_reason"]
                reason_type = reason["type"] if reason else ""
                yield parsed["text"], parsed["meta_info"]["id"], reason_type
async def run_session_control_backtrack_with_abort(self, replace):
    """Backtrack onto a request that is still streaming and check the outcome.

    A first request is streamed inside a freshly opened session. As soon as
    more than 50 characters have been received, a second, non-streaming
    request is sent that references the first request's rid with
    ``offset=50`` and the given ``replace`` flag.

    Args:
        replace: Value sent as ``session_params["replace"]`` of the second
            request. When true, the first stream is expected to finish with
            "abort" and the second output must equal a session-less request
            built from the same token prefix; when false, the second request
            itself is expected to finish with "abort".
    """
    chunks = [
        "Let me tell you something about France.",
        "The capital of France is",
    ]
    tokenizer = get_tokenizer(self.model)
    chunks_ids = [tokenizer.encode(x) for x in chunks]
    # Only the first chunk keeps its BOS token; later chunks are appended.
    for i in range(1, len(chunks_ids)):
        if chunks_ids[i][0] == tokenizer.bos_token_id:
            chunks_ids[i] = chunks_ids[i][1:]

    # 1. using session control
    session_id = requests.post(
        self.base_url + "/open_session",
        json={"capacity_of_str_len": 1000},
    ).json()
    rid = None

    payload = {
        "input_ids": chunks_ids[0],
        "session_params": {
            "id": session_id,
            "rid": rid,
            "offset": -1,
            "replace": True,
        },
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": 100,
            "no_stop_trim": True,
            "skip_special_tokens": False,
            "ignore_eos": True,
        },
        "stream": True,
    }
    gen_so_far = ""
    finish_reason = ""
    second_output = ""
    # `rid` is rebound on every streamed item (None on the final "[DONE]" item).
    async for chunk, rid, finish_reason_chunk in self.async_generate(payload):
        gen_so_far += chunk
        # Keep the first non-empty finish reason reported by the stream.
        if finish_reason == "":
            finish_reason = finish_reason_chunk
        # Send the second request while the first is still streaming; it is
        # re-sent on later items only if its returned text was empty.
        if len(gen_so_far) > 50 and second_output == "":
            payload2 = {
                "input_ids": chunks_ids[1],
                "session_params": {
                    "id": session_id,
                    "rid": rid,
                    "offset": 50,
                    "replace": replace,
                },
                "sampling_params": {
                    "temperature": 0,
                    "max_new_tokens": 32,
                    "no_stop_trim": True,
                    "skip_special_tokens": False,
                },
                "stream": False,
                "stream_output": True,
            }
            response = requests.post(
                url=self.base_url + "/generate", json=payload2
            ).json()
            second_output = response["text"]
    if replace:
        assert finish_reason == "abort"
    print("first request output:")
    print(gen_so_far)
    print("second request output:")
    print(second_output)

    # close the session
    ret = requests.post(
        self.base_url + "/close_session",
        json={"session_id": session_id},
    )
    assert ret.status_code == 200

    if not replace:
        # NOTE(review): `response` is only bound if the streamed text exceeded
        # 50 characters and the second request was actually sent.
        assert response["meta_info"]["finish_reason"]["type"] == "abort"
    else:
        # 2. not using session control
        # Rebuild the same prompt by hand: first 50 token ids of
        # (chunk 0 + re-encoded streamed text), followed by chunk 1.
        output_ids = tokenizer.encode(gen_so_far)
        if output_ids[0] == tokenizer.bos_token_id:
            output_ids = output_ids[1:]
        input_ids = chunks_ids[0] + output_ids
        input_ids = input_ids[:50] + chunks_ids[1]
        payload = {
            "input_ids": input_ids,
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": 32,
                "no_stop_trim": True,
                "skip_special_tokens": False,
            },
            "stream": False,
            "stream_output": True,
        }
        response = requests.post(
            url=self.base_url + "/generate", json=payload
        ).json()
        output_no_session = response["text"]
        print("second request output without session:")
        print(output_no_session)
        assert second_output == output_no_session
def test_session_control_backtrack_with_abort(self):
    """Run the backtrack scenario with ``replace=True``, then with ``replace=False``."""
    for replace_flag in (True, False):
        scenario = self.run_session_control_backtrack_with_abort(replace=replace_flag)
        asyncio.run(scenario)
def run_session_control_with_branching(
    self, root_prompt, chunks_per_step, gen_len=16
):
    """Compare branching generation through a session with plain requests.

    A root prompt is generated once; every step then appends one chunk to
    each branch. With session control each branch continues from the rid of
    its previous request; without it the full text of each branch is
    concatenated by hand and re-sent. Both output lists must be equal.

    Args:
        root_prompt: Text of the shared first request.
        chunks_per_step: One list of chunks per step, holding one chunk per
            branch; every step must have the same number of branches.
        gen_len: ``max_new_tokens`` used for every request.
    """
    for x in chunks_per_step:
        assert len(x) == len(chunks_per_step[0])

    # 1. using session control
    session_id = requests.post(
        self.base_url + "/open_session",
        json={"capacity_of_str_len": 1000},
    ).json()

    outputs_from_session = []
    # send the root prompt
    response = requests.post(
        self.base_url + "/generate",
        json={
            "text": root_prompt,
            "session_params": {
                "id": session_id,
                "rid": None,
                "offset": 0,
                "replace": False,
            },
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": gen_len,
                "no_stop_trim": True,
                "skip_special_tokens": False,
            },
        },
    ).json()
    # Every branch starts from the root request.
    rid_per_branch = [response["meta_info"]["id"]] * len(chunks_per_step[0])
    outputs_from_session.append(response["text"])

    # send the prompts in branches
    for chunks_for_branches in chunks_per_step:
        for j, chunk in enumerate(chunks_for_branches):
            response = requests.post(
                self.base_url + "/generate",
                json={
                    "text": chunk,
                    "session_params": {
                        "id": session_id,
                        "rid": rid_per_branch[j],
                        "offset": 0,
                        "replace": False,
                    },
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": gen_len,
                        "no_stop_trim": True,
                        "skip_special_tokens": False,
                    },
                },
            ).json()
            # The next step of branch j continues from this request.
            rid = response["meta_info"]["id"]
            rid_per_branch[j] = rid
            outputs_from_session.append(response["text"])

    # close the session
    ret = requests.post(
        self.base_url + "/close_session",
        json={"session_id": session_id},
    )
    assert ret.status_code == 200

    # 2. not use session control
    requests.post(self.base_url + "/flush_cache")

    outputs_normal = []
    input_texts = [root_prompt] * len(chunks_per_step[0])
    # send the root prompt
    response = requests.post(
        self.base_url + "/generate",
        json={
            "text": root_prompt,
            "sampling_params": {
                "temperature": 0,
                "max_new_tokens": gen_len,
                "no_stop_trim": True,
                "skip_special_tokens": False,
            },
        },
    ).json()
    outputs_normal.append(response["text"])
    input_texts = [x + response["text"] for x in input_texts]

    # send the prompts in branches
    for chunks_for_branches in chunks_per_step:
        for j, chunk in enumerate(chunks_for_branches):
            # Each branch accumulates its own prompt and generated text.
            input_texts[j] += chunk
            response = requests.post(
                self.base_url + "/generate",
                json={
                    "text": input_texts[j],
                    "sampling_params": {
                        "temperature": 0,
                        "max_new_tokens": gen_len,
                        "no_stop_trim": True,
                        "skip_special_tokens": False,
                    },
                },
            ).json()
            outputs_normal.append(response["text"])
            input_texts[j] += response["text"]

    print("====== outputs from chunked queries with session control: =======")
    print(outputs_from_session)
    print("====== outputs from normal queries: =======")
    print(outputs_normal)
    assert outputs_from_session == outputs_normal
def test_session_control_with_branching(self):
    """Run the branching comparison on two two-branch, two-step scenarios."""
    scenarios = (
        (
            "First, let me explain in one sentence about AI",
            [
                [
                    "Then, briefly, the positive side of AI is",
                    "But, briefly, AI could be harmful to human",
                ],
                ["For example", "For example"],
            ],
        ),
        (
            "I have three apples.",
            [
                [
                    "I then give one apple to my friend",
                    "My friend give me another apple.",
                ],
                ["I still have", "I now have"],
            ],
        ),
    )
    for prompt, steps in scenarios:
        self.run_session_control_with_branching(
            root_prompt=prompt, chunks_per_step=steps, gen_len=8
        )
class
TestSessionControlVision
(
unittest
.
TestCase
):
@
classmethod
...
...
@@ -197,17 +504,25 @@ class TestSessionControlVision(unittest.TestCase):
text_chunks
=
[
"<|im_start|>system
\n
You are a helpful assistant.<|im_end|>
\n
"
,
"<|im_start|>user
\n
<image>
\n
Describe this image in a very short sentence.<|im_end|>
\n
<|im_start|>assistant
\n
"
,
"<|im_start|>user
\n
<image>
\n
Is this image same with the previous image? Answer yes or no.<|im_end|>
\n
<|im_start|>assistant
\n
"
,
"<|im_start|>user
\n
<image>
\n
Is this image same with the previous image? Answer yes or no.<|im_end|>
\n
<|im_start|>assistant
\n
"
,
"<|im_start|>user
\n
<image>
\n
Is this image same with one of the previous images?<|im_end|>
\n
<|im_start|>assistant
\n
"
,
"<|im_start|>user
\n
<image>
\n
Is this image same with one of the previous images?<|im_end|>
\n
<|im_start|>assistant
\n
"
,
"<|im_start|>user
\n
Describe this image in a very short sentence.<|im_end|>
\n
assistant:"
,
]
image_chunks
=
[
"https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
,
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
,
"https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
,
"https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
,
]
assert
len
(
text_chunks
)
==
len
(
image_chunks
)
+
1
assert
(
len
(
text_chunks
)
==
len
(
image_chunks
)
+
2
)
# the first and the last prompts do not contain images
tokenizer
=
get_tokenizer
(
self
.
model
)
text_input_ids
=
[
tokenizer
.
encode
(
x
)
for
x
in
text_chunks
]
for
i
in
range
(
1
,
len
(
text_input_ids
)):
if
text_input_ids
[
i
][
0
]
==
tokenizer
.
bos_token_id
:
text_input_ids
[
i
]
=
text_input_ids
[
i
][
1
:]
gen_len
=
32
# 1. using session control
session_id
=
requests
.
post
(
...
...
@@ -216,20 +531,32 @@ class TestSessionControlVision(unittest.TestCase):
).
json
()
rid
=
None
# re-opening a session id that is already open must fail with an error response
response
=
requests
.
post
(
self
.
base_url
+
"/open_session"
,
json
=
{
"capacity_of_str_len"
:
1000
,
"session_id"
:
session_id
},
).
json
()
assert
isinstance
(
response
,
dict
)
and
"error"
in
response
first_rid
=
None
outputs_from_session
=
[]
for
i
in
range
(
len
(
text_input_ids
)):
for
i
in
range
(
len
(
text_input_ids
[:
-
1
]
)):
response
=
requests
.
post
(
self
.
base_url
+
"/generate"
,
json
=
{
"input_ids"
:
text_input_ids
[
i
],
"image_data"
:
image_chunks
[
i
-
1
]
if
i
>
0
else
None
,
"modalities"
:
[
"multi-images"
],
"session"
:
[
session_id
,
rid
],
"session_params"
:
{
"id"
:
session_id
,
"rid"
:
rid
,
"offset"
:
0
,
"replace"
:
True
,
},
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
(
16
if
i
>
0
else
0
gen_len
if
i
>
0
else
0
),
# prefill only for the first chunk
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
...
...
@@ -247,12 +574,15 @@ class TestSessionControlVision(unittest.TestCase):
self
.
base_url
+
"/generate"
,
json
=
{
"input_ids"
:
text_input_ids
[
-
1
],
"image_data"
:
image_chunks
[
-
1
:],
"modalities"
:
[
"multi-images"
],
"session"
:
[
session_id
,
first_rid
],
"session_params"
:
{
"id"
:
session_id
,
"rid"
:
first_rid
,
"offset"
:
0
,
"replace"
:
True
,
},
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
"max_new_tokens"
:
gen_len
,
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
...
...
@@ -265,12 +595,15 @@ class TestSessionControlVision(unittest.TestCase):
self
.
base_url
+
"/generate"
,
json
=
{
"input_ids"
:
text_input_ids
[
-
1
],
"image_data"
:
image_chunks
[
-
1
:],
"modalities"
:
[
"multi-images"
],
"session"
:
[
session_id
,
rid
],
"session_params"
:
{
"id"
:
session_id
,
"rid"
:
rid
,
"offset"
:
0
,
"replace"
:
True
,
},
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
"max_new_tokens"
:
gen_len
,
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
...
...
@@ -289,10 +622,15 @@ class TestSessionControlVision(unittest.TestCase):
self
.
base_url
+
"/generate"
,
json
=
{
"input_ids"
:
text_input_ids
[
-
1
],
"session"
:
[
session_id
,
first_rid
],
"session_params"
:
{
"id"
:
session_id
,
"rid"
:
first_rid
,
"offset"
:
0
,
"replace"
:
True
,
},
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
"max_new_tokens"
:
gen_len
,
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
...
...
@@ -306,7 +644,7 @@ class TestSessionControlVision(unittest.TestCase):
input_ids_first_req
=
None
input_ids
=
[]
outputs_normal
=
[]
for
i
in
range
(
len
(
text_input_ids
)):
for
i
in
range
(
len
(
text_input_ids
[:
-
1
]
)):
input_ids
+=
text_input_ids
[
i
]
image_data
=
image_chunks
[:
i
]
if
i
>
0
else
None
response
=
requests
.
post
(
...
...
@@ -318,7 +656,7 @@ class TestSessionControlVision(unittest.TestCase):
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
(
16
if
i
>
0
else
0
gen_len
if
i
>
0
else
0
),
# prefill only for the first chunk
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
...
...
@@ -339,11 +677,9 @@ class TestSessionControlVision(unittest.TestCase):
self
.
base_url
+
"/generate"
,
json
=
{
"input_ids"
:
input_ids_first_req
,
"image_data"
:
image_chunks
[
-
1
:],
"modalities"
:
[
"multi-images"
],
"sampling_params"
:
{
"temperature"
:
0
,
"max_new_tokens"
:
16
,
"max_new_tokens"
:
gen_len
,
"no_stop_trim"
:
True
,
"skip_special_tokens"
:
False
,
},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment