Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
5942dfc0
Unverified
Commit
5942dfc0
authored
Nov 20, 2024
by
Ying Sheng
Committed by
GitHub
Nov 20, 2024
Browse files
[feat] Add session control (#2073)
parent
63a395b9
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
348 additions
and
8 deletions
+348
-8
python/sglang/srt/managers/detokenizer_manager.py
python/sglang/srt/managers/detokenizer_manager.py
+1
-0
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+27
-0
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+3
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+58
-8
python/sglang/srt/managers/session_controller.py
python/sglang/srt/managers/session_controller.py
+62
-0
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+38
-0
python/sglang/srt/server.py
python/sglang/srt/server.py
+26
-0
test/srt/test_session_id.py
test/srt/test_session_id.py
+133
-0
No files found.
python/sglang/srt/managers/detokenizer_manager.py
View file @
5942dfc0
...
...
@@ -175,6 +175,7 @@ class DetokenizerManager:
output_strs
=
output_strs
,
meta_info
=
recv_obj
.
meta_info
,
finished_reason
=
recv_obj
.
finished_reason
,
session_ids
=
recv_obj
.
session_ids
,
)
)
...
...
python/sglang/srt/managers/io_struct.py
View file @
5942dfc0
...
...
@@ -56,6 +56,10 @@ class GenerateReqInput:
# LoRA related
lora_path
:
Optional
[
Union
[
List
[
Optional
[
str
]],
Optional
[
str
]]]
=
None
# Session id info for continual prompting
session_id
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
session_rid
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
def
normalize_batch_and_arguments
(
self
):
if
(
self
.
text
is
None
and
self
.
input_ids
is
None
)
or
(
self
.
text
is
not
None
and
self
.
input_ids
is
not
None
...
...
@@ -200,6 +204,10 @@ class TokenizedGenerateReqInput:
# LoRA related
lora_path
:
Optional
[
str
]
=
None
# None means just use the base model
# Session id info for continual prompting
session_id
:
Optional
[
int
]
=
None
session_rid
:
Optional
[
str
]
=
None
@
dataclass
class
EmbeddingReqInput
:
...
...
@@ -293,6 +301,8 @@ class BatchTokenIDOut:
meta_info
:
List
[
Dict
]
finished_reason
:
List
[
BaseFinishReason
]
no_stop_trim
:
List
[
bool
]
# The updated session unique id
session_ids
:
List
[
str
]
@
dataclass
...
...
@@ -305,6 +315,8 @@ class BatchStrOut:
meta_info
:
List
[
Dict
]
# The finish reason
finished_reason
:
List
[
BaseFinishReason
]
# The updated session unique id
session_ids
:
List
[
str
]
@
dataclass
...
...
@@ -357,3 +369,18 @@ class GetMemPoolSizeReq:
@dataclass
class GetMemPoolSizeReqOutput:
    """Reply to a ``GetMemPoolSizeReq``.

    The scheduler builds this from its ``max_total_num_tokens`` value and
    sends it back to the tokenizer manager.
    """

    # Value reported by the scheduler (its ``max_total_num_tokens``).
    size: int
@dataclass
class OpenSessionReqInput:
    """Request to open a new session.

    Attributes:
        capacity_of_str_len: Capacity value handed to the ``Session`` object
            that the scheduler creates for this request.
        session_id: Id of the session to create. The tokenizer manager
            overwrites this with a freshly generated uuid hex string before
            forwarding the request, and the scheduler reads it back. It is
            declared here (instead of being attached dynamically) so the
            attribute always exists and is visible to type checkers.
    """

    capacity_of_str_len: int
    # Filled in by the tokenizer manager; callers do not need to supply it.
    session_id: Optional[str] = None
@dataclass
class CloseSessionReqInput:
    """Request to close an existing session.

    The scheduler removes the matching entry from its session table, or logs
    a warning when the id is unknown.
    """

    # Id of the session to close.
    session_id: str
@dataclass
class OpenSessionReqOutput:
    """Reply sent by the scheduler after handling an ``OpenSessionReqInput``.

    The tokenizer manager uses ``session_id`` to find and resolve the future
    that the pending ``open_session`` call is awaiting.
    """

    # Id of the session the scheduler was asked to open.
    session_id: str
python/sglang/srt/managers/schedule_batch.py
View file @
5942dfc0
...
...
@@ -180,6 +180,7 @@ class Req:
origin_input_ids
:
Tuple
[
int
],
sampling_params
:
SamplingParams
,
lora_path
:
Optional
[
str
]
=
None
,
session_id
:
Optional
[
str
]
=
None
,
):
# Input and output info
self
.
rid
=
rid
...
...
@@ -188,6 +189,8 @@ class Req:
self
.
origin_input_ids
=
origin_input_ids
self
.
output_ids
=
[]
# Each decode stage's output ids
self
.
fill_ids
=
None
# fill_ids = origin_input_ids + output_ids
self
.
session_id
=
session_id
self
.
sampling_params
=
sampling_params
self
.
lora_path
=
lora_path
...
...
python/sglang/srt/managers/scheduler.py
View file @
5942dfc0
...
...
@@ -37,9 +37,12 @@ from sglang.srt.managers.io_struct import (
AbortReq
,
BatchEmbeddingOut
,
BatchTokenIDOut
,
CloseSessionReqInput
,
FlushCacheReq
,
GetMemPoolSizeReq
,
GetMemPoolSizeReqOutput
,
OpenSessionReqInput
,
OpenSessionReqOutput
,
ProfileReq
,
TokenizedEmbeddingReqInput
,
TokenizedGenerateReqInput
,
...
...
@@ -59,6 +62,7 @@ from sglang.srt.managers.schedule_policy import (
PrefillAdder
,
SchedulePolicy
,
)
from
sglang.srt.managers.session_controller
import
Session
from
sglang.srt.managers.tp_worker
import
TpModelWorker
from
sglang.srt.managers.tp_worker_overlap_thread
import
TpModelWorkerClient
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
...
...
@@ -106,6 +110,9 @@ class Scheduler:
self
.
skip_tokenizer_init
=
server_args
.
skip_tokenizer_init
self
.
enable_metrics
=
server_args
.
enable_metrics
# Session info
self
.
sessions
=
{}
# Init inter-process communication
context
=
zmq
.
Context
(
2
)
...
...
@@ -509,6 +516,11 @@ class Scheduler:
self
.
start_profile
()
else
:
self
.
stop_profile
()
elif
isinstance
(
recv_req
,
OpenSessionReqInput
):
session_id
=
self
.
open_session
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
OpenSessionReqOutput
(
session_id
))
elif
isinstance
(
recv_req
,
CloseSessionReqInput
):
self
.
close_session
(
recv_req
)
elif
isinstance
(
recv_req
,
GetMemPoolSizeReq
):
self
.
send_to_tokenizer
.
send_pyobj
(
GetMemPoolSizeReqOutput
(
self
.
max_total_num_tokens
)
...
...
@@ -520,6 +532,7 @@ class Scheduler:
self
,
recv_req
:
TokenizedGenerateReqInput
,
):
if
recv_req
.
session_id
is
None
or
recv_req
.
session_id
not
in
self
.
sessions
:
req
=
Req
(
recv_req
.
rid
,
recv_req
.
input_text
,
...
...
@@ -528,6 +541,21 @@ class Scheduler:
lora_path
=
recv_req
.
lora_path
,
)
req
.
tokenizer
=
self
.
tokenizer
if
recv_req
.
session_id
is
not
None
:
req
.
finished_reason
=
FINISH_ABORT
(
f
"Invalid request: session id
{
recv_req
.
session_id
}
does not exist"
)
self
.
waiting_queue
.
append
(
req
)
return
else
:
# Handle sessions
session
=
self
.
sessions
[
recv_req
.
session_id
]
req
,
new_session_id
=
session
.
create_req
(
recv_req
,
self
.
tokenizer
)
del
self
.
sessions
[
recv_req
.
session_id
]
self
.
sessions
[
new_session_id
]
=
session
if
isinstance
(
req
.
finished_reason
,
FINISH_ABORT
):
self
.
waiting_queue
.
append
(
req
)
return
# Image inputs
if
recv_req
.
image_inputs
is
not
None
:
...
...
@@ -1151,6 +1179,7 @@ class Scheduler:
output_skip_special_tokens
=
[]
output_spaces_between_special_tokens
=
[]
output_no_stop_trim
=
[]
output_session_ids
=
[]
else
:
# embedding or reward model
output_embeddings
=
[]
...
...
@@ -1178,6 +1207,7 @@ class Scheduler:
req
.
sampling_params
.
spaces_between_special_tokens
)
output_no_stop_trim
.
append
(
req
.
sampling_params
.
no_stop_trim
)
output_session_ids
.
append
(
req
.
session_id
)
meta_info
=
{
"prompt_tokens"
:
len
(
req
.
origin_input_ids
),
...
...
@@ -1228,6 +1258,7 @@ class Scheduler:
output_meta_info
,
output_finished_reason
,
output_no_stop_trim
,
output_session_ids
,
)
)
else
:
# embedding or reward model
...
...
@@ -1330,6 +1361,25 @@ class Scheduler:
)
logger
.
info
(
"Profiler is done"
)
def open_session(self, recv_req: OpenSessionReqInput) -> str:
    """Register a new ``Session`` under the id carried by ``recv_req``.

    If the id is already present in ``self.sessions`` the existing session is
    left untouched and only a warning is logged. The requested id is returned
    in both cases.
    """
    requested_id = recv_req.session_id
    if requested_id not in self.sessions:
        self.sessions[requested_id] = Session(
            recv_req.capacity_of_str_len, requested_id
        )
        return requested_id

    # Duplicate id: keep the current session and just report it.
    logger.warning(f"session id {requested_id} already exist, cannot open.")
    return requested_id
def close_session(self, recv_req: CloseSessionReqInput):
    """Drop the session named by ``recv_req.session_id`` from ``self.sessions``.

    An unknown id is not an error; it is only logged as a warning.
    """
    target_id = recv_req.session_id
    if target_id in self.sessions:
        del self.sessions[target_id]
        return

    logger.warning(f"session id {target_id} does not exist, cannot delete.")
def
run_scheduler_process
(
server_args
:
ServerArgs
,
...
...
python/sglang/srt/managers/session_controller.py
0 → 100644
View file @
5942dfc0
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import
copy
import
uuid
from
dataclasses
import
dataclass
from
typing
import
Optional
from
sglang.srt.managers.io_struct
import
TokenizedGenerateReqInput
from
sglang.srt.managers.schedule_batch
import
FINISH_ABORT
,
List
,
Req
class Session:
    """History of the requests accepted within one session.

    Each call to ``create_req`` builds a new ``Req`` whose input is the last
    accepted request's input and (truncated) output followed by the new
    input ids, so a client can keep extending a conversation by sending only
    the new chunk.
    """

    def __init__(self, capacity_of_str_len: int, session_id: Optional[str] = None):
        """Create an empty session.

        Args:
            capacity_of_str_len: Stored on the instance; not consulted
                anywhere in this class.
            session_id: Id to use; a random uuid hex string is generated
                when it is ``None``.
        """
        self.session_id = session_id if session_id is not None else uuid.uuid4().hex
        self.capacity_of_str_len = capacity_of_str_len
        # Accepted requests, oldest first.
        self.reqs: List[Req] = []

    def _find_req_index(self, rid) -> Optional[int]:
        """Return the index of the most recent request with id ``rid``, or ``None``."""
        for idx in range(len(self.reqs) - 1, -1, -1):
            if self.reqs[idx].rid == rid:
                return idx
        return None

    def create_req(self, req: "TokenizedGenerateReqInput", tokenizer):
        """Build the next ``Req`` of this session from a tokenized request.

        The session id is renewed on every call. When ``req.session_rid`` is
        given, the history is first rolled back so that the request with that
        id becomes the latest one. If no such request exists, the returned
        ``Req`` is marked with ``FINISH_ABORT`` and the history is left
        unchanged, so a bad ``session_rid`` can no longer wipe the session.

        Args:
            req: The tokenized generate request.
            tokenizer: Tokenizer attached to the created ``Req``.

        Returns:
            A ``(new_req, session_id)`` tuple, where ``session_id`` is the
            renewed id of this session.
        """
        # renew session id
        self.session_id = uuid.uuid4().hex

        rid_missing = False
        if req.session_rid is not None:
            anchor = self._find_req_index(req.session_rid)
            if anchor is None:
                rid_missing = True
            else:
                # Discard everything issued after the anchor request in one
                # in-place delete instead of re-slicing the list per element.
                del self.reqs[anchor + 1 :]

        if self.reqs and not rid_missing:
            last = self.reqs[-1]
            input_ids = (
                last.origin_input_ids
                + last.output_ids[: last.sampling_params.max_new_tokens]
                + req.input_ids
            )
        else:
            input_ids = req.input_ids

        new_req = Req(
            req.rid,
            None,
            input_ids,
            req.sampling_params,
            lora_path=req.lora_path,
            session_id=self.session_id,
        )
        new_req.tokenizer = tokenizer

        if rid_missing:
            new_req.finished_reason = FINISH_ABORT(
                f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
            )
        else:
            self.reqs.append(new_req)
        return new_req, self.session_id
python/sglang/srt/managers/tokenizer_manager.py
View file @
5942dfc0
...
...
@@ -23,6 +23,7 @@ import os
import
signal
import
sys
import
time
import
uuid
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
fastapi
...
...
@@ -42,11 +43,14 @@ from sglang.srt.managers.io_struct import (
BatchEmbeddingOut
,
BatchStrOut
,
BatchTokenIDOut
,
CloseSessionReqInput
,
EmbeddingReqInput
,
FlushCacheReq
,
GenerateReqInput
,
GetMemPoolSizeReq
,
GetMemPoolSizeReqOutput
,
OpenSessionReqInput
,
OpenSessionReqOutput
,
ProfileReq
,
TokenizedEmbeddingReqInput
,
TokenizedGenerateReqInput
,
...
...
@@ -146,6 +150,9 @@ class TokenizerManager:
self
.
model_update_lock
=
asyncio
.
Lock
()
self
.
model_update_result
=
None
# For session info
self
.
session_futures
=
{}
# session_id -> asyncio.Future resolved when the scheduler confirms the session
# Others
self
.
gracefully_exit
=
False
...
...
@@ -211,6 +218,8 @@ class TokenizerManager:
return_logprob
=
obj
.
return_logprob
logprob_start_len
=
obj
.
logprob_start_len
top_logprobs_num
=
obj
.
top_logprobs_num
session_id
=
obj
.
session_id
session_rid
=
obj
.
session_rid
if
len
(
input_ids
)
>=
self
.
context_len
:
raise
ValueError
(
...
...
@@ -236,6 +245,8 @@ class TokenizerManager:
top_logprobs_num
,
obj
.
stream
,
obj
.
lora_path
,
session_id
=
session_id
,
session_rid
=
session_rid
,
)
elif
isinstance
(
obj
,
EmbeddingReqInput
):
tokenized_obj
=
TokenizedEmbeddingReqInput
(
...
...
@@ -451,6 +462,26 @@ class TokenizerManager:
else
:
return
False
,
"Another update is in progress. Please try again later."
async def open_session(
    self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
    """Ask the scheduler to open a session and wait for its confirmation.

    A fresh uuid hex id is generated and attached to ``obj``, the request is
    sent to the scheduler, and a future stored in ``self.session_futures`` is
    awaited. The receive loop resolves that future when the matching
    ``OpenSessionReqOutput`` arrives.

    Returns:
        The session id carried by the scheduler's reply.
    """
    if self.to_create_loop:
        self.create_handle_loop()

    request_key = uuid.uuid4().hex
    obj.session_id = request_key
    self.send_to_scheduler.send_pyobj(obj)

    pending = asyncio.Future()
    self.session_futures[request_key] = pending
    confirmed_id = await pending

    # The entry is removed under the id the scheduler echoed back.
    del self.session_futures[confirmed_id]
    return confirmed_id
async def close_session(
    self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
):
    """Forward a close-session request to the scheduler.

    No reply is awaited from the scheduler; the call returns once the send
    itself has completed.
    """
    assert not self.to_create_loop, "close session should not be the first request"
    # NOTE(review): awaiting the send presumes send_to_scheduler is an asyncio
    # zmq socket whose send_pyobj returns an awaitable; confirm.
    pending_send = self.send_to_scheduler.send_pyobj(obj)
    await pending_send
def
create_abort_task
(
self
,
obj
:
GenerateReqInput
):
# Abort the request if the client is disconnected.
async
def
abort_request
():
...
...
@@ -521,6 +552,11 @@ class TokenizerManager:
if
len
(
self
.
mem_pool_size_tmp
)
==
self
.
server_args
.
dp_size
:
self
.
mem_pool_size
.
set_result
(
self
.
mem_pool_size_tmp
)
continue
elif
isinstance
(
recv_obj
,
OpenSessionReqOutput
):
self
.
session_futures
[
recv_obj
.
session_id
].
set_result
(
recv_obj
.
session_id
)
continue
assert
isinstance
(
recv_obj
,
(
BatchStrOut
,
BatchEmbeddingOut
,
BatchTokenIDOut
)
...
...
@@ -536,11 +572,13 @@ class TokenizerManager:
out_dict
=
{
"text"
:
recv_obj
.
output_strs
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
"session_id"
:
recv_obj
.
session_ids
[
i
],
}
elif
isinstance
(
recv_obj
,
BatchTokenIDOut
):
out_dict
=
{
"token_ids"
:
recv_obj
.
output_ids
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
"session_id"
:
recv_obj
.
session_ids
[
i
],
}
else
:
assert
isinstance
(
recv_obj
,
BatchEmbeddingOut
)
...
...
python/sglang/srt/server.py
View file @
5942dfc0
...
...
@@ -50,8 +50,10 @@ from sglang.srt.managers.data_parallel_controller import (
)
from
sglang.srt.managers.detokenizer_manager
import
run_detokenizer_process
from
sglang.srt.managers.io_struct
import
(
CloseSessionReqInput
,
EmbeddingReqInput
,
GenerateReqInput
,
OpenSessionReqInput
,
UpdateWeightReqInput
,
)
from
sglang.srt.managers.scheduler
import
run_scheduler_process
...
...
@@ -215,6 +217,30 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
)
@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
    """Open a session, and return its unique session id."""
    try:
        opened_id = await tokenizer_manager.open_session(obj, request)
    except Exception as exc:
        # Any failure is reported to the client as a 400 with the message.
        failure = {"error": {"message": str(exc)}}
        return ORJSONResponse(failure, status_code=HTTPStatus.BAD_REQUEST)
    return opened_id
@app.api_route("/close_session", methods=["GET", "POST"])
async def close_session(obj: CloseSessionReqInput, request: Request):
    """Close the session"""
    try:
        await tokenizer_manager.close_session(obj, request)
    except Exception as exc:
        # Any failure is reported to the client as a 400 with the message.
        failure = {"error": {"message": str(exc)}}
        return ORJSONResponse(failure, status_code=HTTPStatus.BAD_REQUEST)
    else:
        return Response(status_code=200)
@
time_func_latency
async
def
generate_request
(
obj
:
GenerateReqInput
,
request
:
Request
):
"""Handle a generate request."""
...
...
test/srt/test_session_id.py
0 → 100644
View file @
5942dfc0
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# FIXME: Make it a CI test
import
requests
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
url = "http://localhost:30000"


def _generate(prompt, session_id, max_new_tokens, **extra_fields):
    """POST one /generate request for ``session_id``, print and return the JSON reply.

    ``extra_fields`` (e.g. ``session_rid``) are placed between ``session_id``
    and ``sampling_params`` in the request body.
    """
    body = {
        "text": prompt,
        "session_id": session_id,
        **extra_fields,
        "sampling_params": {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        },
    }
    reply = requests.post(url + "/generate", json=body).json()
    print(reply, "\n")
    return reply


# Open a session
response = requests.post(
    url + "/open_session",
    json={"capacity_of_str_len": 1000},
)
session_id = response.json()
print("session_id", session_id, "\n")

# Prefill only
reply = _generate("chunk 1", session_id, 0)
session_id = reply["session_id"]

# Generate
reply = _generate("Chunk 2", session_id, 16)
session_id = reply["session_id"]
rid = reply["meta_info"]["id"]

# Generate
reply = _generate("Chunk 3", session_id, 2)
session_id = reply["session_id"]
rid_to_del = reply["meta_info"]["id"]

# Interrupt and re-generate
reply = _generate("Chunk 4", session_id, 16, session_rid=rid)
session_id = reply["session_id"]

# Query a session based on a deleted request, should see finish reason abort
_generate("Chunk 4", session_id, 16, session_rid=rid_to_del)

# Close session
ret = requests.post(
    url + "/close_session",
    json={"session_id": session_id},
)
print(ret, "\n")

# Query a deleted session, should see finish reason abort
_generate("chunk 1", session_id, 0)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment