Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
5942dfc0
Unverified
Commit
5942dfc0
authored
Nov 20, 2024
by
Ying Sheng
Committed by
GitHub
Nov 20, 2024
Browse files
[feat] Add session control (#2073)
parent
63a395b9
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
348 additions
and
8 deletions
+348
-8
python/sglang/srt/managers/detokenizer_manager.py
python/sglang/srt/managers/detokenizer_manager.py
+1
-0
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+27
-0
python/sglang/srt/managers/schedule_batch.py
python/sglang/srt/managers/schedule_batch.py
+3
-0
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+58
-8
python/sglang/srt/managers/session_controller.py
python/sglang/srt/managers/session_controller.py
+62
-0
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+38
-0
python/sglang/srt/server.py
python/sglang/srt/server.py
+26
-0
test/srt/test_session_id.py
test/srt/test_session_id.py
+133
-0
No files found.
python/sglang/srt/managers/detokenizer_manager.py
View file @
5942dfc0
...
@@ -175,6 +175,7 @@ class DetokenizerManager:
...
@@ -175,6 +175,7 @@ class DetokenizerManager:
output_strs
=
output_strs
,
output_strs
=
output_strs
,
meta_info
=
recv_obj
.
meta_info
,
meta_info
=
recv_obj
.
meta_info
,
finished_reason
=
recv_obj
.
finished_reason
,
finished_reason
=
recv_obj
.
finished_reason
,
session_ids
=
recv_obj
.
session_ids
,
)
)
)
)
...
...
python/sglang/srt/managers/io_struct.py
View file @
5942dfc0
...
@@ -56,6 +56,10 @@ class GenerateReqInput:
...
@@ -56,6 +56,10 @@ class GenerateReqInput:
# LoRA related
# LoRA related
lora_path
:
Optional
[
Union
[
List
[
Optional
[
str
]],
Optional
[
str
]]]
=
None
lora_path
:
Optional
[
Union
[
List
[
Optional
[
str
]],
Optional
[
str
]]]
=
None
# Session id info for continual prompting
session_id
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
session_rid
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
def
normalize_batch_and_arguments
(
self
):
def
normalize_batch_and_arguments
(
self
):
if
(
self
.
text
is
None
and
self
.
input_ids
is
None
)
or
(
if
(
self
.
text
is
None
and
self
.
input_ids
is
None
)
or
(
self
.
text
is
not
None
and
self
.
input_ids
is
not
None
self
.
text
is
not
None
and
self
.
input_ids
is
not
None
...
@@ -200,6 +204,10 @@ class TokenizedGenerateReqInput:
...
@@ -200,6 +204,10 @@ class TokenizedGenerateReqInput:
# LoRA related
# LoRA related
lora_path
:
Optional
[
str
]
=
None
# None means just use the base model
lora_path
:
Optional
[
str
]
=
None
# None means just use the base model
# Session id info for continual prompting
session_id
:
Optional
[
int
]
=
None
session_rid
:
Optional
[
str
]
=
None
@
dataclass
@
dataclass
class
EmbeddingReqInput
:
class
EmbeddingReqInput
:
...
@@ -293,6 +301,8 @@ class BatchTokenIDOut:
...
@@ -293,6 +301,8 @@ class BatchTokenIDOut:
meta_info
:
List
[
Dict
]
meta_info
:
List
[
Dict
]
finished_reason
:
List
[
BaseFinishReason
]
finished_reason
:
List
[
BaseFinishReason
]
no_stop_trim
:
List
[
bool
]
no_stop_trim
:
List
[
bool
]
# The updated session unique id
session_ids
:
List
[
str
]
@
dataclass
@
dataclass
...
@@ -305,6 +315,8 @@ class BatchStrOut:
...
@@ -305,6 +315,8 @@ class BatchStrOut:
meta_info
:
List
[
Dict
]
meta_info
:
List
[
Dict
]
# The finish reason
# The finish reason
finished_reason
:
List
[
BaseFinishReason
]
finished_reason
:
List
[
BaseFinishReason
]
# The update session unique id
session_ids
:
List
[
str
]
@
dataclass
@
dataclass
...
@@ -357,3 +369,18 @@ class GetMemPoolSizeReq:
...
@@ -357,3 +369,18 @@ class GetMemPoolSizeReq:
@
dataclass
@
dataclass
class
GetMemPoolSizeReqOutput
:
class
GetMemPoolSizeReqOutput
:
size
:
int
size
:
int
@dataclass
class OpenSessionReqInput:
    """Request (tokenizer manager -> scheduler) to open a continual-prompting session."""

    # Capacity hint for the session's string length; stored on the
    # scheduler-side Session object (not otherwise consumed in this
    # change — confirm intended use).
    capacity_of_str_len: int
@dataclass
class CloseSessionReqInput:
    """Request to close (delete) an existing session."""

    # Id of the session to delete; an unknown id only produces a
    # scheduler-side warning, not an error reply.
    session_id: str
@dataclass
class OpenSessionReqOutput:
    """Reply (scheduler -> tokenizer manager) carrying the opened session's id."""

    # Unique id of the session that was just opened.
    session_id: str
python/sglang/srt/managers/schedule_batch.py
View file @
5942dfc0
...
@@ -180,6 +180,7 @@ class Req:
...
@@ -180,6 +180,7 @@ class Req:
origin_input_ids
:
Tuple
[
int
],
origin_input_ids
:
Tuple
[
int
],
sampling_params
:
SamplingParams
,
sampling_params
:
SamplingParams
,
lora_path
:
Optional
[
str
]
=
None
,
lora_path
:
Optional
[
str
]
=
None
,
session_id
:
Optional
[
str
]
=
None
,
):
):
# Input and output info
# Input and output info
self
.
rid
=
rid
self
.
rid
=
rid
...
@@ -188,6 +189,8 @@ class Req:
...
@@ -188,6 +189,8 @@ class Req:
self
.
origin_input_ids
=
origin_input_ids
self
.
origin_input_ids
=
origin_input_ids
self
.
output_ids
=
[]
# Each decode stage's output ids
self
.
output_ids
=
[]
# Each decode stage's output ids
self
.
fill_ids
=
None
# fill_ids = origin_input_ids + output_ids
self
.
fill_ids
=
None
# fill_ids = origin_input_ids + output_ids
self
.
session_id
=
session_id
self
.
sampling_params
=
sampling_params
self
.
sampling_params
=
sampling_params
self
.
lora_path
=
lora_path
self
.
lora_path
=
lora_path
...
...
python/sglang/srt/managers/scheduler.py
View file @
5942dfc0
...
@@ -37,9 +37,12 @@ from sglang.srt.managers.io_struct import (
...
@@ -37,9 +37,12 @@ from sglang.srt.managers.io_struct import (
AbortReq
,
AbortReq
,
BatchEmbeddingOut
,
BatchEmbeddingOut
,
BatchTokenIDOut
,
BatchTokenIDOut
,
CloseSessionReqInput
,
FlushCacheReq
,
FlushCacheReq
,
GetMemPoolSizeReq
,
GetMemPoolSizeReq
,
GetMemPoolSizeReqOutput
,
GetMemPoolSizeReqOutput
,
OpenSessionReqInput
,
OpenSessionReqOutput
,
ProfileReq
,
ProfileReq
,
TokenizedEmbeddingReqInput
,
TokenizedEmbeddingReqInput
,
TokenizedGenerateReqInput
,
TokenizedGenerateReqInput
,
...
@@ -59,6 +62,7 @@ from sglang.srt.managers.schedule_policy import (
...
@@ -59,6 +62,7 @@ from sglang.srt.managers.schedule_policy import (
PrefillAdder
,
PrefillAdder
,
SchedulePolicy
,
SchedulePolicy
,
)
)
from
sglang.srt.managers.session_controller
import
Session
from
sglang.srt.managers.tp_worker
import
TpModelWorker
from
sglang.srt.managers.tp_worker
import
TpModelWorker
from
sglang.srt.managers.tp_worker_overlap_thread
import
TpModelWorkerClient
from
sglang.srt.managers.tp_worker_overlap_thread
import
TpModelWorkerClient
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
from
sglang.srt.mem_cache.chunk_cache
import
ChunkCache
...
@@ -106,6 +110,9 @@ class Scheduler:
...
@@ -106,6 +110,9 @@ class Scheduler:
self
.
skip_tokenizer_init
=
server_args
.
skip_tokenizer_init
self
.
skip_tokenizer_init
=
server_args
.
skip_tokenizer_init
self
.
enable_metrics
=
server_args
.
enable_metrics
self
.
enable_metrics
=
server_args
.
enable_metrics
# Session info
self
.
sessions
=
{}
# Init inter-process communication
# Init inter-process communication
context
=
zmq
.
Context
(
2
)
context
=
zmq
.
Context
(
2
)
...
@@ -509,6 +516,11 @@ class Scheduler:
...
@@ -509,6 +516,11 @@ class Scheduler:
self
.
start_profile
()
self
.
start_profile
()
else
:
else
:
self
.
stop_profile
()
self
.
stop_profile
()
elif
isinstance
(
recv_req
,
OpenSessionReqInput
):
session_id
=
self
.
open_session
(
recv_req
)
self
.
send_to_tokenizer
.
send_pyobj
(
OpenSessionReqOutput
(
session_id
))
elif
isinstance
(
recv_req
,
CloseSessionReqInput
):
self
.
close_session
(
recv_req
)
elif
isinstance
(
recv_req
,
GetMemPoolSizeReq
):
elif
isinstance
(
recv_req
,
GetMemPoolSizeReq
):
self
.
send_to_tokenizer
.
send_pyobj
(
self
.
send_to_tokenizer
.
send_pyobj
(
GetMemPoolSizeReqOutput
(
self
.
max_total_num_tokens
)
GetMemPoolSizeReqOutput
(
self
.
max_total_num_tokens
)
...
@@ -520,14 +532,30 @@ class Scheduler:
...
@@ -520,14 +532,30 @@ class Scheduler:
self
,
self
,
recv_req
:
TokenizedGenerateReqInput
,
recv_req
:
TokenizedGenerateReqInput
,
):
):
req
=
Req
(
if
recv_req
.
session_id
is
None
or
recv_req
.
session_id
not
in
self
.
sessions
:
recv_req
.
rid
,
req
=
Req
(
recv_req
.
input_text
,
recv_req
.
rid
,
recv_req
.
input_ids
,
recv_req
.
input_text
,
recv_req
.
sampling_params
,
recv_req
.
input_ids
,
lora_path
=
recv_req
.
lora_path
,
recv_req
.
sampling_params
,
)
lora_path
=
recv_req
.
lora_path
,
req
.
tokenizer
=
self
.
tokenizer
)
req
.
tokenizer
=
self
.
tokenizer
if
recv_req
.
session_id
is
not
None
:
req
.
finished_reason
=
FINISH_ABORT
(
f
"Invalid request: session id
{
recv_req
.
session_id
}
does not exist"
)
self
.
waiting_queue
.
append
(
req
)
return
else
:
# Handle sessions
session
=
self
.
sessions
[
recv_req
.
session_id
]
req
,
new_session_id
=
session
.
create_req
(
recv_req
,
self
.
tokenizer
)
del
self
.
sessions
[
recv_req
.
session_id
]
self
.
sessions
[
new_session_id
]
=
session
if
isinstance
(
req
.
finished_reason
,
FINISH_ABORT
):
self
.
waiting_queue
.
append
(
req
)
return
# Image inputs
# Image inputs
if
recv_req
.
image_inputs
is
not
None
:
if
recv_req
.
image_inputs
is
not
None
:
...
@@ -1151,6 +1179,7 @@ class Scheduler:
...
@@ -1151,6 +1179,7 @@ class Scheduler:
output_skip_special_tokens
=
[]
output_skip_special_tokens
=
[]
output_spaces_between_special_tokens
=
[]
output_spaces_between_special_tokens
=
[]
output_no_stop_trim
=
[]
output_no_stop_trim
=
[]
output_session_ids
=
[]
else
:
# embedding or reward model
else
:
# embedding or reward model
output_embeddings
=
[]
output_embeddings
=
[]
...
@@ -1178,6 +1207,7 @@ class Scheduler:
...
@@ -1178,6 +1207,7 @@ class Scheduler:
req
.
sampling_params
.
spaces_between_special_tokens
req
.
sampling_params
.
spaces_between_special_tokens
)
)
output_no_stop_trim
.
append
(
req
.
sampling_params
.
no_stop_trim
)
output_no_stop_trim
.
append
(
req
.
sampling_params
.
no_stop_trim
)
output_session_ids
.
append
(
req
.
session_id
)
meta_info
=
{
meta_info
=
{
"prompt_tokens"
:
len
(
req
.
origin_input_ids
),
"prompt_tokens"
:
len
(
req
.
origin_input_ids
),
...
@@ -1228,6 +1258,7 @@ class Scheduler:
...
@@ -1228,6 +1258,7 @@ class Scheduler:
output_meta_info
,
output_meta_info
,
output_finished_reason
,
output_finished_reason
,
output_no_stop_trim
,
output_no_stop_trim
,
output_session_ids
,
)
)
)
)
else
:
# embedding or reward model
else
:
# embedding or reward model
...
@@ -1330,6 +1361,25 @@ class Scheduler:
...
@@ -1330,6 +1361,25 @@ class Scheduler:
)
)
logger
.
info
(
"Profiler is done"
)
logger
.
info
(
"Profiler is done"
)
def open_session(self, recv_req: OpenSessionReqInput) -> str:
    """Register a new Session under the id carried by `recv_req`.

    The id is echoed back unchanged. If it is already registered, the
    existing session is left untouched and only a warning is logged.
    """
    sid = recv_req.session_id
    if sid in self.sessions:
        # Duplicate open: keep the existing session, warn, still reply.
        logger.warning(f"session id {sid} already exist, cannot open.")
        return sid
    self.sessions[sid] = Session(recv_req.capacity_of_str_len, sid)
    return sid
def close_session(self, recv_req: CloseSessionReqInput):
    """Delete the session named in `recv_req`; unknown ids only warn."""
    sid = recv_req.session_id
    if sid not in self.sessions:
        # Nothing to delete — log and return without error.
        logger.warning(f"session id {sid} does not exist, cannot delete.")
        return
    del self.sessions[sid]
def
run_scheduler_process
(
def
run_scheduler_process
(
server_args
:
ServerArgs
,
server_args
:
ServerArgs
,
...
...
python/sglang/srt/managers/session_controller.py
0 → 100644
View file @
5942dfc0
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import
copy
import
uuid
from
dataclasses
import
dataclass
from
typing
import
Optional
from
sglang.srt.managers.io_struct
import
TokenizedGenerateReqInput
from
sglang.srt.managers.schedule_batch
import
FINISH_ABORT
,
List
,
Req
class Session:
    """Scheduler-side state for one continual-prompting session.

    Keeps the ordered history of requests served under the session so a
    follow-up request can be prefixed with the previous prompt + output
    tokens. The session id is renewed on every request.
    """

    def __init__(self, capacity_of_str_len: int, session_id: Optional[str] = None):
        # Fall back to a fresh random hex id when the caller supplies none.
        self.session_id = session_id if session_id is not None else uuid.uuid4().hex
        # Capacity hint from OpenSessionReqInput (stored but not consumed
        # here — confirm intended use).
        self.capacity_of_str_len = capacity_of_str_len
        # Request history, oldest first.
        self.reqs: List[Req] = []

    def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
        """Build a Req continuing this session's history.

        Returns `(new_req, new_session_id)`. If `req.session_rid` names a
        request to resume from, newer history entries are discarded first;
        an unknown rid aborts the request via FINISH_ABORT.
        """
        # Renew the session id on every request.
        self.session_id = uuid.uuid4().hex
        if req.session_rid is not None:
            # Roll back history until the request the client wants to
            # continue from. pop() is O(1) per step, unlike rebinding
            # self.reqs = self.reqs[:-1] (O(n) copy each iteration).
            while self.reqs and self.reqs[-1].rid != req.session_rid:
                self.reqs.pop()
        if self.reqs:
            # Continue from the last kept request: its prompt plus its
            # decoded tokens (truncated to max_new_tokens), then the new
            # input.
            last = self.reqs[-1]
            input_ids = (
                last.origin_input_ids
                + last.output_ids[: last.sampling_params.max_new_tokens]
                + req.input_ids
            )
        else:
            input_ids = req.input_ids
        new_req = Req(
            req.rid,
            None,
            input_ids,
            req.sampling_params,
            lora_path=req.lora_path,
            session_id=self.session_id,
        )
        new_req.tokenizer = tokenizer
        if req.session_rid is not None and len(self.reqs) == 0:
            # The rollback above emptied the history: the requested rid was
            # never part of this session. Abort instead of generating.
            new_req.finished_reason = FINISH_ABORT(
                f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
            )
        else:
            self.reqs.append(new_req)
        return new_req, self.session_id
python/sglang/srt/managers/tokenizer_manager.py
View file @
5942dfc0
...
@@ -23,6 +23,7 @@ import os
...
@@ -23,6 +23,7 @@ import os
import
signal
import
signal
import
sys
import
sys
import
time
import
time
import
uuid
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
import
fastapi
import
fastapi
...
@@ -42,11 +43,14 @@ from sglang.srt.managers.io_struct import (
...
@@ -42,11 +43,14 @@ from sglang.srt.managers.io_struct import (
BatchEmbeddingOut
,
BatchEmbeddingOut
,
BatchStrOut
,
BatchStrOut
,
BatchTokenIDOut
,
BatchTokenIDOut
,
CloseSessionReqInput
,
EmbeddingReqInput
,
EmbeddingReqInput
,
FlushCacheReq
,
FlushCacheReq
,
GenerateReqInput
,
GenerateReqInput
,
GetMemPoolSizeReq
,
GetMemPoolSizeReq
,
GetMemPoolSizeReqOutput
,
GetMemPoolSizeReqOutput
,
OpenSessionReqInput
,
OpenSessionReqOutput
,
ProfileReq
,
ProfileReq
,
TokenizedEmbeddingReqInput
,
TokenizedEmbeddingReqInput
,
TokenizedGenerateReqInput
,
TokenizedGenerateReqInput
,
...
@@ -146,6 +150,9 @@ class TokenizerManager:
...
@@ -146,6 +150,9 @@ class TokenizerManager:
self
.
model_update_lock
=
asyncio
.
Lock
()
self
.
model_update_lock
=
asyncio
.
Lock
()
self
.
model_update_result
=
None
self
.
model_update_result
=
None
# For session info
self
.
session_futures
=
{}
# session_id -> asyncio event
# Others
# Others
self
.
gracefully_exit
=
False
self
.
gracefully_exit
=
False
...
@@ -211,6 +218,8 @@ class TokenizerManager:
...
@@ -211,6 +218,8 @@ class TokenizerManager:
return_logprob
=
obj
.
return_logprob
return_logprob
=
obj
.
return_logprob
logprob_start_len
=
obj
.
logprob_start_len
logprob_start_len
=
obj
.
logprob_start_len
top_logprobs_num
=
obj
.
top_logprobs_num
top_logprobs_num
=
obj
.
top_logprobs_num
session_id
=
obj
.
session_id
session_rid
=
obj
.
session_rid
if
len
(
input_ids
)
>=
self
.
context_len
:
if
len
(
input_ids
)
>=
self
.
context_len
:
raise
ValueError
(
raise
ValueError
(
...
@@ -236,6 +245,8 @@ class TokenizerManager:
...
@@ -236,6 +245,8 @@ class TokenizerManager:
top_logprobs_num
,
top_logprobs_num
,
obj
.
stream
,
obj
.
stream
,
obj
.
lora_path
,
obj
.
lora_path
,
session_id
=
session_id
,
session_rid
=
session_rid
,
)
)
elif
isinstance
(
obj
,
EmbeddingReqInput
):
elif
isinstance
(
obj
,
EmbeddingReqInput
):
tokenized_obj
=
TokenizedEmbeddingReqInput
(
tokenized_obj
=
TokenizedEmbeddingReqInput
(
...
@@ -451,6 +462,26 @@ class TokenizerManager:
...
@@ -451,6 +462,26 @@ class TokenizerManager:
else
:
else
:
return
False
,
"Another update is in progress. Please try again later."
return
False
,
"Another update is in progress. Please try again later."
async def open_session(
    self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
):
    """Open a session on the scheduler and await its confirmed id.

    Generates a fresh uuid, attaches it to `obj`, sends the request to the
    scheduler, and blocks on a Future resolved by the result handler when
    the OpenSessionReqOutput arrives. Returns the confirmed session id.
    """
    if self.to_create_loop:
        self.create_handle_loop()

    session_id = uuid.uuid4().hex
    obj.session_id = session_id
    # Register the future BEFORE sending so the reply handler can never
    # observe a missing entry, regardless of task interleaving.
    self.session_futures[session_id] = asyncio.Future()
    self.send_to_scheduler.send_pyobj(obj)
    # Await the scheduler's echo; delete by the key we registered, not the
    # echoed value, so cleanup is correct even if they ever differ.
    confirmed_id = await self.session_futures[session_id]
    del self.session_futures[session_id]
    return confirmed_id
async def close_session(
    self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
):
    """Forward a close-session request to the scheduler (fire-and-forget)."""
    # The handle loop is lazily created by the first request; a close can
    # only refer to a session opened earlier, so the loop must exist.
    assert not self.to_create_loop, "close session should not be the first request"
    # No reply is expected: the scheduler deletes the session (or logs a
    # warning for an unknown id) without sending anything back.
    await self.send_to_scheduler.send_pyobj(obj)
def
create_abort_task
(
self
,
obj
:
GenerateReqInput
):
def
create_abort_task
(
self
,
obj
:
GenerateReqInput
):
# Abort the request if the client is disconnected.
# Abort the request if the client is disconnected.
async
def
abort_request
():
async
def
abort_request
():
...
@@ -521,6 +552,11 @@ class TokenizerManager:
...
@@ -521,6 +552,11 @@ class TokenizerManager:
if
len
(
self
.
mem_pool_size_tmp
)
==
self
.
server_args
.
dp_size
:
if
len
(
self
.
mem_pool_size_tmp
)
==
self
.
server_args
.
dp_size
:
self
.
mem_pool_size
.
set_result
(
self
.
mem_pool_size_tmp
)
self
.
mem_pool_size
.
set_result
(
self
.
mem_pool_size_tmp
)
continue
continue
elif
isinstance
(
recv_obj
,
OpenSessionReqOutput
):
self
.
session_futures
[
recv_obj
.
session_id
].
set_result
(
recv_obj
.
session_id
)
continue
assert
isinstance
(
assert
isinstance
(
recv_obj
,
(
BatchStrOut
,
BatchEmbeddingOut
,
BatchTokenIDOut
)
recv_obj
,
(
BatchStrOut
,
BatchEmbeddingOut
,
BatchTokenIDOut
)
...
@@ -536,11 +572,13 @@ class TokenizerManager:
...
@@ -536,11 +572,13 @@ class TokenizerManager:
out_dict
=
{
out_dict
=
{
"text"
:
recv_obj
.
output_strs
[
i
],
"text"
:
recv_obj
.
output_strs
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
"session_id"
:
recv_obj
.
session_ids
[
i
],
}
}
elif
isinstance
(
recv_obj
,
BatchTokenIDOut
):
elif
isinstance
(
recv_obj
,
BatchTokenIDOut
):
out_dict
=
{
out_dict
=
{
"token_ids"
:
recv_obj
.
output_ids
[
i
],
"token_ids"
:
recv_obj
.
output_ids
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
"session_id"
:
recv_obj
.
session_ids
[
i
],
}
}
else
:
else
:
assert
isinstance
(
recv_obj
,
BatchEmbeddingOut
)
assert
isinstance
(
recv_obj
,
BatchEmbeddingOut
)
...
...
python/sglang/srt/server.py
View file @
5942dfc0
...
@@ -50,8 +50,10 @@ from sglang.srt.managers.data_parallel_controller import (
...
@@ -50,8 +50,10 @@ from sglang.srt.managers.data_parallel_controller import (
)
)
from
sglang.srt.managers.detokenizer_manager
import
run_detokenizer_process
from
sglang.srt.managers.detokenizer_manager
import
run_detokenizer_process
from
sglang.srt.managers.io_struct
import
(
from
sglang.srt.managers.io_struct
import
(
CloseSessionReqInput
,
EmbeddingReqInput
,
EmbeddingReqInput
,
GenerateReqInput
,
GenerateReqInput
,
OpenSessionReqInput
,
UpdateWeightReqInput
,
UpdateWeightReqInput
,
)
)
from
sglang.srt.managers.scheduler
import
run_scheduler_process
from
sglang.srt.managers.scheduler
import
run_scheduler_process
...
@@ -215,6 +217,30 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
...
@@ -215,6 +217,30 @@ async def update_weights(obj: UpdateWeightReqInput, request: Request):
)
)
@app.api_route("/open_session", methods=["GET", "POST"])
async def open_session(obj: OpenSessionReqInput, request: Request):
    """Open a session, and return its unique session id."""
    try:
        # The tokenizer manager generates the id and waits for the
        # scheduler's confirmation before returning it.
        return await tokenizer_manager.open_session(obj, request)
    except Exception as err:
        body = {"error": {"message": str(err)}}
        return ORJSONResponse(body, status_code=HTTPStatus.BAD_REQUEST)
@app.api_route("/close_session", methods=["GET", "POST"])
async def close_session(obj: CloseSessionReqInput, request: Request):
    """Close the session"""
    try:
        # Fire-and-forget toward the scheduler; success is a bare 200.
        await tokenizer_manager.close_session(obj, request)
        return Response(status_code=200)
    except Exception as err:
        body = {"error": {"message": str(err)}}
        return ORJSONResponse(body, status_code=HTTPStatus.BAD_REQUEST)
@
time_func_latency
@
time_func_latency
async
def
generate_request
(
obj
:
GenerateReqInput
,
request
:
Request
):
async
def
generate_request
(
obj
:
GenerateReqInput
,
request
:
Request
):
"""Handle a generate request."""
"""Handle a generate request."""
...
...
test/srt/test_session_id.py
0 → 100644
View file @
5942dfc0
"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
# FIXME: Make it a CI test
import requests

from sglang.srt.hf_transformers_utils import get_tokenizer

url = "http://localhost:30000"


def _generate(payload):
    """POST /generate with `payload`, echo the JSON reply, return it."""
    reply = requests.post(url + "/generate", json=payload).json()
    print(reply, "\n")
    return reply


# Open a session
response = requests.post(
    url + "/open_session",
    json={"capacity_of_str_len": 1000},
)
session_id = response.json()
print("session_id", session_id, "\n")

# Prefill only
reply = _generate(
    {
        "text": "chunk 1",
        "session_id": session_id,
        "sampling_params": {"temperature": 0, "max_new_tokens": 0},
    }
)
session_id = reply["session_id"]

# Generate
reply = _generate(
    {
        "text": "Chunk 2",
        "session_id": session_id,
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    }
)
session_id = reply["session_id"]
rid = reply["meta_info"]["id"]

# Generate
reply = _generate(
    {
        "text": "Chunk 3",
        "session_id": session_id,
        "sampling_params": {"temperature": 0, "max_new_tokens": 2},
    }
)
session_id = reply["session_id"]
rid_to_del = reply["meta_info"]["id"]

# Interrupt and re-generate
reply = _generate(
    {
        "text": "Chunk 4",
        "session_id": session_id,
        "session_rid": rid,
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    }
)
session_id = reply["session_id"]

# Query a session based on a deleted request, should see finish reason abort
_generate(
    {
        "text": "Chunk 4",
        "session_id": session_id,
        "session_rid": rid_to_del,
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    }
)

# Close session
ret = requests.post(
    url + "/close_session",
    json={"session_id": session_id},
)
print(ret, "\n")

# Query a deleted session, should see finish reason abort
_generate(
    {
        "text": "chunk 1",
        "session_id": session_id,
        "sampling_params": {"temperature": 0, "max_new_tokens": 0},
    }
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment