Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e0e09fce
Unverified
Commit
e0e09fce
authored
Dec 29, 2024
by
Ying Sheng
Committed by
GitHub
Dec 29, 2024
Browse files
[Session] Update session control interface (#2635)
parent
9c05c689
Changes
6
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
530 additions
and
90 deletions
+530
-90
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+15
-8
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+23
-10
python/sglang/srt/managers/session_controller.py
python/sglang/srt/managers/session_controller.py
+102
-27
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+15
-10
python/sglang/srt/server.py
python/sglang/srt/server.py
+4
-0
test/srt/test_session_control.py
test/srt/test_session_control.py
+371
-35
No files found.
python/sglang/srt/managers/io_struct.py
View file @
e0e09fce
...
@@ -27,6 +27,14 @@ from sglang.srt.managers.schedule_batch import BaseFinishReason
...
@@ -27,6 +27,14 @@ from sglang.srt.managers.schedule_batch import BaseFinishReason
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.sampling.sampling_params
import
SamplingParams
@dataclass
class SessionParams:
    """Per-request session options carried alongside a generate request.

    Every field is optional and defaults to ``None``; the object is
    typically built from a plain dict via ``SessionParams(**params)``.
    """

    # Identifier of the session the request belongs to.
    id: Optional[str] = None
    # Request id of an earlier request within the session that this
    # request refers to.
    rid: Optional[str] = None
    # Token offset applied to the referenced request's token sequence.
    # NOTE(review): presumably a truncation point before the new input is
    # appended -- confirm against Session.create_req.
    offset: Optional[int] = None
    # Whether the request replaces, rather than appends to, the
    # referenced request.
    replace: Optional[bool] = None
@
dataclass
@
dataclass
class
GenerateReqInput
:
class
GenerateReqInput
:
# The input prompt. It can be a single prompt or a batch of prompts.
# The input prompt. It can be a single prompt or a batch of prompts.
...
@@ -58,10 +66,8 @@ class GenerateReqInput:
...
@@ -58,10 +66,8 @@ class GenerateReqInput:
# LoRA related
# LoRA related
lora_path
:
Optional
[
Union
[
List
[
Optional
[
str
]],
Optional
[
str
]]]
=
None
lora_path
:
Optional
[
Union
[
List
[
Optional
[
str
]],
Optional
[
str
]]]
=
None
# Session id info for continual prompting
# Session info for continual prompting
session
:
Optional
[
session_params
:
Optional
[
Union
[
List
[
Dict
],
Dict
]]
=
None
Union
[
List
[
Tuple
[
str
,
Optional
[
str
]]],
Tuple
[
str
,
Optional
[
str
]]]
]
=
None
def
normalize_batch_and_arguments
(
self
):
def
normalize_batch_and_arguments
(
self
):
if
(
if
(
...
@@ -223,9 +229,8 @@ class TokenizedGenerateReqInput:
...
@@ -223,9 +229,8 @@ class TokenizedGenerateReqInput:
# The input embeds
# The input embeds
input_embeds
:
Optional
[
Union
[
List
[
List
[
List
[
float
]]],
List
[
List
[
float
]]]]
=
None
input_embeds
:
Optional
[
Union
[
List
[
List
[
List
[
float
]]],
List
[
List
[
float
]]]]
=
None
# Session id info for continual prompting
# Session info for continual prompting
session_id
:
Optional
[
str
]
=
None
session_params
:
Optional
[
SessionParams
]
=
None
session_rid
:
Optional
[
str
]
=
None
@
dataclass
@
dataclass
...
@@ -468,6 +473,7 @@ class ProfileReq(Enum):
...
@@ -468,6 +473,7 @@ class ProfileReq(Enum):
@
dataclass
@
dataclass
class
OpenSessionReqInput
:
class
OpenSessionReqInput
:
capacity_of_str_len
:
int
capacity_of_str_len
:
int
session_id
:
Optional
[
str
]
=
None
@
dataclass
@
dataclass
...
@@ -477,4 +483,5 @@ class CloseSessionReqInput:
...
@@ -477,4 +483,5 @@ class CloseSessionReqInput:
@
dataclass
@
dataclass
class
OpenSessionReqOutput
:
class
OpenSessionReqOutput
:
session_id
:
str
session_id
:
Optional
[
str
]
success
:
bool
python/sglang/srt/managers/scheduler.py
View file @
e0e09fce
...
@@ -22,7 +22,7 @@ import warnings
...
@@ -22,7 +22,7 @@ import warnings
from
collections
import
deque
from
collections
import
deque
from
concurrent
import
futures
from
concurrent
import
futures
from
types
import
SimpleNamespace
from
types
import
SimpleNamespace
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Dict
,
List
,
Optional
,
Tuple
import
psutil
import
psutil
import
setproctitle
import
setproctitle
...
@@ -498,8 +498,10 @@ class Scheduler:
...
@@ -498,8 +498,10 @@ class Scheduler:
else
:
else
:
self
.
stop_profile
()
self
.
stop_profile
()
elif
isinstance
(
recv_req
,
OpenSessionReqInput
):
elif
isinstance
(
recv_req
,
OpenSessionReqInput
):
session_id
=
self
.
open_session
(
recv_req
)
session_id
,
success
=
self
.
open_session
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
OpenSessionReqOutput
(
session_id
))
self
.
send_to_tokenizer
.
send_pyobj
(
OpenSessionReqOutput
(
session_id
=
session_id
,
success
=
success
)
)
elif
isinstance
(
recv_req
,
CloseSessionReqInput
):
elif
isinstance
(
recv_req
,
CloseSessionReqInput
):
self
.
close_session
(
recv_req
)
self
.
close_session
(
recv_req
)
else
:
else
:
...
@@ -510,7 +512,11 @@ class Scheduler:
...
@@ -510,7 +512,11 @@ class Scheduler:
recv_req
:
TokenizedGenerateReqInput
,
recv_req
:
TokenizedGenerateReqInput
,
):
):
# Create a new request
# Create a new request
if
recv_req
.
session_id
is
None
or
recv_req
.
session_id
not
in
self
.
sessions
:
if
(
recv_req
.
session_params
is
None
or
recv_req
.
session_params
.
id
is
None
or
recv_req
.
session_params
.
id
not
in
self
.
sessions
):
if
recv_req
.
input_embeds
is
not
None
:
if
recv_req
.
input_embeds
is
not
None
:
# Generate fake input_ids based on the length of input_embeds
# Generate fake input_ids based on the length of input_embeds
...
@@ -532,15 +538,18 @@ class Scheduler:
...
@@ -532,15 +538,18 @@ class Scheduler:
)
)
req
.
tokenizer
=
self
.
tokenizer
req
.
tokenizer
=
self
.
tokenizer
if
recv_req
.
session_id
is
not
None
:
if
(
recv_req
.
session_params
is
not
None
and
recv_req
.
session_params
.
id
is
not
None
):
req
.
finished_reason
=
FINISH_ABORT
(
req
.
finished_reason
=
FINISH_ABORT
(
f
"Invalid request: session id
{
recv_req
.
session_id
}
does not exist"
f
"Invalid request: session id
{
recv_req
.
session_
params
.
id
}
does not exist"
)
)
self
.
waiting_queue
.
append
(
req
)
self
.
waiting_queue
.
append
(
req
)
return
return
else
:
else
:
# Create a new request from a previ
s
ou session
# Create a new request from a previou
s
session
session
=
self
.
sessions
[
recv_req
.
session_id
]
session
=
self
.
sessions
[
recv_req
.
session_
params
.
id
]
req
=
session
.
create_req
(
recv_req
,
self
.
tokenizer
)
req
=
session
.
create_req
(
recv_req
,
self
.
tokenizer
)
if
isinstance
(
req
.
finished_reason
,
FINISH_ABORT
):
if
isinstance
(
req
.
finished_reason
,
FINISH_ABORT
):
self
.
waiting_queue
.
append
(
req
)
self
.
waiting_queue
.
append
(
req
)
...
@@ -1500,16 +1509,20 @@ class Scheduler:
...
@@ -1500,16 +1509,20 @@ class Scheduler:
)
)
logger
.
info
(
"Profiler is done"
)
logger
.
info
(
"Profiler is done"
)
def
open_session
(
self
,
recv_req
:
OpenSessionReqInput
)
->
str
:
def
open_session
(
self
,
recv_req
:
OpenSessionReqInput
)
->
Tuple
[
Optional
[
str
],
bool
]
:
# handle error
# handle error
session_id
=
recv_req
.
session_id
session_id
=
recv_req
.
session_id
if
session_id
in
self
.
sessions
:
if
session_id
in
self
.
sessions
:
logger
.
warning
(
f
"session id
{
session_id
}
already exist, cannot open."
)
logger
.
warning
(
f
"session id
{
session_id
}
already exist, cannot open."
)
return
session_id
,
False
elif
session_id
is
None
:
logger
.
warning
(
f
"session id is None, cannot open."
)
return
session_id
,
False
else
:
else
:
self
.
sessions
[
session_id
]
=
Session
(
self
.
sessions
[
session_id
]
=
Session
(
recv_req
.
capacity_of_str_len
,
session_id
recv_req
.
capacity_of_str_len
,
session_id
)
)
return
session_id
return
session_id
,
True
def
close_session
(
self
,
recv_req
:
CloseSessionReqInput
):
def
close_session
(
self
,
recv_req
:
CloseSessionReqInput
):
# handle error
# handle error
...
...
python/sglang/srt/managers/session_controller.py
View file @
e0e09fce
...
@@ -10,41 +10,116 @@
...
@@ -10,41 +10,116 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
import
logging
import
uuid
import
uuid
from
typing
import
Dict
,
Optional
from
sglang.srt.managers.io_struct
import
TokenizedGenerateReqInput
from
sglang.srt.managers.io_struct
import
TokenizedGenerateReqInput
from
sglang.srt.managers.schedule_batch
import
FINISH_ABORT
,
List
,
Req
from
sglang.srt.managers.schedule_batch
import
Req
class SessionReqNode:
    """Node in a tree of requests.

    Each node wraps one request object and is linked to the node it was
    created from (``parent``) and to the nodes created from it (``childs``).
    The wrapped request must expose ``rid`` and ``finished_reason`` and
    accept assignment to ``to_abort``.
    """

    def __init__(self, req, parent=None, childs=None):
        """Wrap ``req`` and attach the new node to ``parent``.

        Args:
            req: The request object held by this node.
            parent: Node this one descends from, or ``None`` for a root.
                When given, the new node is appended to ``parent.childs``.
            childs: Initial list of child nodes. ``None`` or an empty list
                results in a fresh empty list.
        """
        self.req = req
        self.parent = parent
        if parent is not None:
            parent.childs.append(self)
        self.childs = childs if childs else []

    def clear_childs(self, req_dict):
        """Clear every descendant of this node, keeping the node itself.

        Args:
            req_dict: Mapping from request id to node; entries of all
                descendants are removed from it.
        """
        for child in self.childs:
            child.clear(req_dict)
        self.childs = []

    def clear(self, req_dict):
        """Clear this node and its whole subtree.

        Unfinished requests in the subtree are flagged with ``to_abort``
        and every visited node's entry is dropped from ``req_dict``.

        Args:
            req_dict: Mapping from request id to node.
        """
        for child in self.childs:
            child.clear(req_dict)
        self.abort()
        # pop() instead of del: the entry may already be gone when this node
        # was cleared earlier (e.g. as a descendant of another cleared node),
        # and clearing twice must not raise KeyError.
        req_dict.pop(self.req.rid, None)

    def abort(self):
        """Flag the wrapped request for abortion unless it already finished."""
        if self.req.finished_reason is None:
            self.req.to_abort = True

    def __str__(self):
        """Return a multi-line rendering of the subtree rooted at this node."""
        return self._str_helper(self.req.rid)

    def _str_helper(self, prefix=""):
        """Render the subtree, one line per leaf.

        Args:
            prefix: Text accumulated for the path leading to this node.

        Returns:
            The first child continues the current line with `` -- <rid>``;
            each further child starts a new line, indented by the length of
            the incoming prefix and introduced by `` \\- <rid>``.
        """
        if not self.childs:
            return prefix + "\n"

        first, *others = self.childs
        rendered = first._str_helper(prefix + " -- " + first.req.rid)
        indent = " " * len(prefix)
        for child in others:
            # The backslash is escaped explicitly; the runtime text is the
            # same as before but no longer an invalid escape sequence.
            rendered += child._str_helper(indent + " \\- " + child.req.rid)
        return rendered
class
Session
:
class
Session
:
def
__init__
(
self
,
capacity_of_str_len
:
int
,
session_id
:
str
=
None
):
def
__init__
(
self
,
capacity_of_str_len
:
int
,
session_id
:
Optional
[
str
]
=
None
):
self
.
session_id
=
session_id
if
session_id
is
not
None
else
uuid
.
uuid4
().
hex
self
.
session_id
=
session_id
if
session_id
is
not
None
else
uuid
.
uuid4
().
hex
self
.
capacity_of_str_len
=
capacity_of_str_len
self
.
capacity_of_str_len
=
capacity_of_str_len
self
.
req
s
:
List
[
Req
]
=
[]
self
.
req
_nodes
:
Dict
[
str
,
SessionReqNode
]
=
{}
def
create_req
(
self
,
req
:
TokenizedGenerateReqInput
,
tokenizer
):
def
create_req
(
self
,
req
:
TokenizedGenerateReqInput
,
tokenizer
):
if
req
.
session_rid
is
not
None
:
assert
req
.
session_params
is
not
None
while
len
(
self
.
reqs
)
>
0
:
session_params
=
req
.
session_params
if
self
.
reqs
[
-
1
].
rid
==
req
.
session_rid
:
break
last_req_node
=
None
self
.
reqs
=
self
.
reqs
[:
-
1
]
last_req
=
None
abort
=
False
if
session_params
.
replace
:
if
session_params
.
rid
is
None
:
for
_
,
req_node
in
self
.
req_nodes
.
items
():
req_node
.
clear
(
self
.
req_nodes
)
else
:
if
session_params
.
rid
not
in
self
.
req_nodes
:
abort
=
True
else
:
last_req_node
=
self
.
req_nodes
[
session_params
.
rid
]
last_req_node
.
abort
()
last_req
=
last_req_node
.
req
last_req_node
.
clear_childs
(
self
.
req_nodes
)
else
:
else
:
self
.
reqs
=
[]
if
session_params
.
rid
is
not
None
:
if
len
(
self
.
reqs
)
>
0
:
if
session_params
.
rid
not
in
self
.
req_nodes
:
abort
=
True
else
:
last_req_node
=
self
.
req_nodes
[
session_params
.
rid
]
last_req
=
last_req_node
.
req
if
not
last_req
.
finished
():
logging
.
warning
(
"The request in a session is appending to a request that hasn't finished."
)
abort
=
True
if
last_req
is
not
None
:
# trim bos token if it is an append
if
req
.
input_ids
[
0
]
==
tokenizer
.
bos_token_id
:
req
.
input_ids
=
req
.
input_ids
[
1
:]
input_ids
=
(
input_ids
=
(
self
.
reqs
[
-
1
].
origin_input_ids
last_req
.
origin_input_ids
+
self
.
reqs
[
-
1
].
output_ids
[
+
last_req
.
output_ids
[:
last_req
.
sampling_params
.
max_new_tokens
]
:
self
.
reqs
[
-
1
].
sampling_params
.
max_new_tokens
]
+
req
.
input_ids
)
)
if
session_params
.
offset
and
session_params
.
offset
!=
0
:
input_ids
=
input_ids
[:
session_params
.
offset
]
+
req
.
input_ids
else
:
input_ids
+=
req
.
input_ids
input_ids_unpadded
=
(
input_ids_unpadded
=
(
self
.
reqs
[
-
1
].
origin_input_ids_unpadded
last_req
.
origin_input_ids_unpadded
+
self
.
reqs
[
-
1
].
output_ids
[
+
last_req
.
output_ids
[:
last_req
.
sampling_params
.
max_new_tokens
]
:
self
.
reqs
[
-
1
].
sampling_params
.
max_new_tokens
]
+
req
.
input_ids
)
)
if
session_params
.
offset
and
session_params
.
offset
!=
0
:
input_ids_unpadded
=
(
input_ids_unpadded
[:
session_params
.
offset
]
+
req
.
input_ids
)
else
:
input_ids_unpadded
+=
req
.
input_ids
else
:
else
:
input_ids
=
req
.
input_ids
input_ids
=
req
.
input_ids
input_ids_unpadded
=
req
.
input_ids
input_ids_unpadded
=
req
.
input_ids
...
@@ -57,13 +132,13 @@ class Session:
...
@@ -57,13 +132,13 @@ class Session:
lora_path
=
req
.
lora_path
,
lora_path
=
req
.
lora_path
,
session_id
=
self
.
session_id
,
session_id
=
self
.
session_id
,
)
)
if
l
en
(
self
.
reqs
)
>
0
:
if
l
ast_req
is
not
None
:
new_req
.
image_inputs
=
self
.
reqs
[
-
1
]
.
image_inputs
new_req
.
image_inputs
=
last_req
.
image_inputs
new_req
.
tokenizer
=
tokenizer
new_req
.
tokenizer
=
tokenizer
if
req
.
session_rid
is
not
None
and
len
(
self
.
reqs
)
==
0
:
if
abort
:
new_req
.
finished_reason
=
FINISH_ABORT
(
new_req
.
to_abort
=
True
f
"Invalid request: requested session rid
{
req
.
session_rid
}
does not exist in the session history"
)
else
:
else
:
self
.
reqs
.
append
(
new_req
)
new_req_node
=
SessionReqNode
(
new_req
,
last_req_node
)
self
.
req_nodes
[
req
.
rid
]
=
new_req_node
return
new_req
return
new_req
python/sglang/srt/managers/tokenizer_manager.py
View file @
e0e09fce
...
@@ -53,6 +53,7 @@ from sglang.srt.managers.io_struct import (
...
@@ -53,6 +53,7 @@ from sglang.srt.managers.io_struct import (
OpenSessionReqInput
,
OpenSessionReqInput
,
OpenSessionReqOutput
,
OpenSessionReqOutput
,
ProfileReq
,
ProfileReq
,
SessionParams
,
TokenizedEmbeddingReqInput
,
TokenizedEmbeddingReqInput
,
TokenizedGenerateReqInput
,
TokenizedGenerateReqInput
,
UpdateWeightFromDiskReqInput
,
UpdateWeightFromDiskReqInput
,
...
@@ -264,8 +265,9 @@ class TokenizerManager:
...
@@ -264,8 +265,9 @@ class TokenizerManager:
return_logprob
=
obj
.
return_logprob
return_logprob
=
obj
.
return_logprob
logprob_start_len
=
obj
.
logprob_start_len
logprob_start_len
=
obj
.
logprob_start_len
top_logprobs_num
=
obj
.
top_logprobs_num
top_logprobs_num
=
obj
.
top_logprobs_num
session_id
=
obj
.
session
[
0
]
if
obj
.
session
else
None
session_params
=
(
session_rid
=
obj
.
session
[
1
]
if
obj
.
session
else
None
SessionParams
(
**
obj
.
session_params
)
if
obj
.
session_params
else
None
)
if
obj
.
input_ids
is
not
None
and
len
(
input_ids
)
>=
self
.
context_len
:
if
obj
.
input_ids
is
not
None
and
len
(
input_ids
)
>=
self
.
context_len
:
raise
ValueError
(
raise
ValueError
(
...
@@ -292,8 +294,7 @@ class TokenizerManager:
...
@@ -292,8 +294,7 @@ class TokenizerManager:
obj
.
stream
,
obj
.
stream
,
lora_path
=
obj
.
lora_path
,
lora_path
=
obj
.
lora_path
,
input_embeds
=
input_embeds
,
input_embeds
=
input_embeds
,
session_id
=
session_id
,
session_params
=
session_params
,
session_rid
=
session_rid
,
)
)
elif
isinstance
(
obj
,
EmbeddingReqInput
):
elif
isinstance
(
obj
,
EmbeddingReqInput
):
tokenized_obj
=
TokenizedEmbeddingReqInput
(
tokenized_obj
=
TokenizedEmbeddingReqInput
(
...
@@ -552,12 +553,16 @@ class TokenizerManager:
...
@@ -552,12 +553,16 @@ class TokenizerManager:
):
):
self
.
auto_create_handle_loop
()
self
.
auto_create_handle_loop
()
session_id
=
uuid
.
uuid4
().
hex
if
obj
.
session_id
is
None
:
obj
.
session_id
=
session_id
obj
.
session_id
=
uuid
.
uuid4
().
hex
elif
obj
.
session_id
in
self
.
session_futures
:
return
None
self
.
send_to_scheduler
.
send_pyobj
(
obj
)
self
.
send_to_scheduler
.
send_pyobj
(
obj
)
self
.
session_futures
[
session_id
]
=
asyncio
.
Future
()
session_id
=
await
self
.
session_futures
[
session_id
]
self
.
session_futures
[
obj
.
session_id
]
=
asyncio
.
Future
()
del
self
.
session_futures
[
session_id
]
session_id
=
await
self
.
session_futures
[
obj
.
session_id
]
del
self
.
session_futures
[
obj
.
session_id
]
return
session_id
return
session_id
async
def
close_session
(
async
def
close_session
(
...
@@ -709,7 +714,7 @@ class TokenizerManager:
...
@@ -709,7 +714,7 @@ class TokenizerManager:
)
)
elif
isinstance
(
recv_obj
,
OpenSessionReqOutput
):
elif
isinstance
(
recv_obj
,
OpenSessionReqOutput
):
self
.
session_futures
[
recv_obj
.
session_id
].
set_result
(
self
.
session_futures
[
recv_obj
.
session_id
].
set_result
(
recv_obj
.
session_id
recv_obj
.
session_id
if
recv_obj
.
success
else
None
)
)
elif
isinstance
(
recv_obj
,
UpdateWeightFromDiskReqOutput
):
elif
isinstance
(
recv_obj
,
UpdateWeightFromDiskReqOutput
):
if
self
.
server_args
.
dp_size
==
1
:
if
self
.
server_args
.
dp_size
==
1
:
...
...
python/sglang/srt/server.py
View file @
e0e09fce
...
@@ -259,6 +259,10 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
...
@@ -259,6 +259,10 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
"""Open a session, and return its unique session id."""
"""Open a session, and return its unique session id."""
try
:
try
:
session_id
=
await
tokenizer_manager
.
open_session
(
obj
,
request
)
session_id
=
await
tokenizer_manager
.
open_session
(
obj
,
request
)
if
session_id
is
None
:
raise
Exception
(
"Failed to open the session. Check if a session with the same id is still open."
)
return
session_id
return
session_id
except
Exception
as
e
:
except
Exception
as
e
:
return
_create_error_response
(
e
)
return
_create_error_response
(
e
)
...
...
test/srt/test_session_control.py
View file @
e0e09fce
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment