Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e1e595d7
Unverified
Commit
e1e595d7
authored
Nov 25, 2024
by
Ying Sheng
Committed by
GitHub
Nov 25, 2024
Browse files
[feat] Refactor session control interface and add CI (#2173)
parent
5ada33ff
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
180 additions
and
154 deletions
+180
-154
python/sglang/srt/managers/detokenizer_manager.py
python/sglang/srt/managers/detokenizer_manager.py
+0
-1
python/sglang/srt/managers/io_struct.py
python/sglang/srt/managers/io_struct.py
+5
-8
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+1
-6
python/sglang/srt/managers/session_controller.py
python/sglang/srt/managers/session_controller.py
+3
-3
python/sglang/srt/managers/tokenizer_manager.py
python/sglang/srt/managers/tokenizer_manager.py
+2
-4
scripts/playground/test_session_id.py
scripts/playground/test_session_id.py
+0
-132
test/srt/run_suite.py
test/srt/run_suite.py
+1
-0
test/srt/test_session_control.py
test/srt/test_session_control.py
+168
-0
No files found.
python/sglang/srt/managers/detokenizer_manager.py
View file @
e1e595d7
...
@@ -173,7 +173,6 @@ class DetokenizerManager:
...
@@ -173,7 +173,6 @@ class DetokenizerManager:
output_strs
=
output_strs
,
output_strs
=
output_strs
,
meta_info
=
recv_obj
.
meta_info
,
meta_info
=
recv_obj
.
meta_info
,
finished_reason
=
recv_obj
.
finished_reason
,
finished_reason
=
recv_obj
.
finished_reason
,
session_ids
=
recv_obj
.
session_ids
,
)
)
)
)
...
...
python/sglang/srt/managers/io_struct.py
View file @
e1e595d7
...
@@ -19,7 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
...
@@ -19,7 +19,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
import
uuid
import
uuid
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
enum
import
Enum
from
enum
import
Enum
from
typing
import
Dict
,
List
,
Optional
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Tuple
,
Union
from
sglang.srt.managers.schedule_batch
import
BaseFinishReason
from
sglang.srt.managers.schedule_batch
import
BaseFinishReason
from
sglang.srt.sampling.sampling_params
import
SamplingParams
from
sglang.srt.sampling.sampling_params
import
SamplingParams
...
@@ -55,8 +55,9 @@ class GenerateReqInput:
...
@@ -55,8 +55,9 @@ class GenerateReqInput:
lora_path
:
Optional
[
Union
[
List
[
Optional
[
str
]],
Optional
[
str
]]]
=
None
lora_path
:
Optional
[
Union
[
List
[
Optional
[
str
]],
Optional
[
str
]]]
=
None
# Session id info for continual prompting
# Session id info for continual prompting
session_id
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
session
:
Optional
[
session_rid
:
Optional
[
Union
[
List
[
str
],
str
]]
=
None
Union
[
List
[
Tuple
[
str
,
Optional
[
str
]]],
Tuple
[
str
,
Optional
[
str
]]]
]
=
None
def
normalize_batch_and_arguments
(
self
):
def
normalize_batch_and_arguments
(
self
):
if
(
self
.
text
is
None
and
self
.
input_ids
is
None
)
or
(
if
(
self
.
text
is
None
and
self
.
input_ids
is
None
)
or
(
...
@@ -203,7 +204,7 @@ class TokenizedGenerateReqInput:
...
@@ -203,7 +204,7 @@ class TokenizedGenerateReqInput:
lora_path
:
Optional
[
str
]
=
None
# None means just use the base model
lora_path
:
Optional
[
str
]
=
None
# None means just use the base model
# Session id info for continual prompting
# Session id info for continual prompting
session_id
:
Optional
[
int
]
=
None
session_id
:
Optional
[
str
]
=
None
session_rid
:
Optional
[
str
]
=
None
session_rid
:
Optional
[
str
]
=
None
...
@@ -299,8 +300,6 @@ class BatchTokenIDOut:
...
@@ -299,8 +300,6 @@ class BatchTokenIDOut:
meta_info
:
List
[
Dict
]
meta_info
:
List
[
Dict
]
finished_reason
:
List
[
BaseFinishReason
]
finished_reason
:
List
[
BaseFinishReason
]
no_stop_trim
:
List
[
bool
]
no_stop_trim
:
List
[
bool
]
# The updated session unique id
session_ids
:
List
[
str
]
@
dataclass
@
dataclass
...
@@ -313,8 +312,6 @@ class BatchStrOut:
...
@@ -313,8 +312,6 @@ class BatchStrOut:
meta_info
:
List
[
Dict
]
meta_info
:
List
[
Dict
]
# The finish reason
# The finish reason
finished_reason
:
List
[
BaseFinishReason
]
finished_reason
:
List
[
BaseFinishReason
]
# The update session unique id
session_ids
:
List
[
str
]
@
dataclass
@
dataclass
...
...
python/sglang/srt/managers/scheduler.py
View file @
e1e595d7
...
@@ -542,9 +542,7 @@ class Scheduler:
...
@@ -542,9 +542,7 @@ class Scheduler:
else
:
else
:
# Handle sessions
# Handle sessions
session
=
self
.
sessions
[
recv_req
.
session_id
]
session
=
self
.
sessions
[
recv_req
.
session_id
]
req
,
new_session_id
=
session
.
create_req
(
recv_req
,
self
.
tokenizer
)
req
=
session
.
create_req
(
recv_req
,
self
.
tokenizer
)
del
self
.
sessions
[
recv_req
.
session_id
]
self
.
sessions
[
new_session_id
]
=
session
if
isinstance
(
req
.
finished_reason
,
FINISH_ABORT
):
if
isinstance
(
req
.
finished_reason
,
FINISH_ABORT
):
self
.
waiting_queue
.
append
(
req
)
self
.
waiting_queue
.
append
(
req
)
return
return
...
@@ -1188,7 +1186,6 @@ class Scheduler:
...
@@ -1188,7 +1186,6 @@ class Scheduler:
output_skip_special_tokens
=
[]
output_skip_special_tokens
=
[]
output_spaces_between_special_tokens
=
[]
output_spaces_between_special_tokens
=
[]
output_no_stop_trim
=
[]
output_no_stop_trim
=
[]
output_session_ids
=
[]
else
:
# embedding or reward model
else
:
# embedding or reward model
output_embeddings
=
[]
output_embeddings
=
[]
...
@@ -1216,7 +1213,6 @@ class Scheduler:
...
@@ -1216,7 +1213,6 @@ class Scheduler:
req
.
sampling_params
.
spaces_between_special_tokens
req
.
sampling_params
.
spaces_between_special_tokens
)
)
output_no_stop_trim
.
append
(
req
.
sampling_params
.
no_stop_trim
)
output_no_stop_trim
.
append
(
req
.
sampling_params
.
no_stop_trim
)
output_session_ids
.
append
(
req
.
session_id
)
meta_info
=
{
meta_info
=
{
"prompt_tokens"
:
len
(
req
.
origin_input_ids
),
"prompt_tokens"
:
len
(
req
.
origin_input_ids
),
...
@@ -1267,7 +1263,6 @@ class Scheduler:
...
@@ -1267,7 +1263,6 @@ class Scheduler:
output_meta_info
,
output_meta_info
,
output_finished_reason
,
output_finished_reason
,
output_no_stop_trim
,
output_no_stop_trim
,
output_session_ids
,
)
)
)
)
else
:
# embedding or reward model
else
:
# embedding or reward model
...
...
python/sglang/srt/managers/session_controller.py
View file @
e1e595d7
...
@@ -26,13 +26,13 @@ class Session:
...
@@ -26,13 +26,13 @@ class Session:
self
.
reqs
:
List
[
Req
]
=
[]
self
.
reqs
:
List
[
Req
]
=
[]
def
create_req
(
self
,
req
:
TokenizedGenerateReqInput
,
tokenizer
):
def
create_req
(
self
,
req
:
TokenizedGenerateReqInput
,
tokenizer
):
# renew session id
self
.
session_id
=
uuid
.
uuid4
().
hex
if
req
.
session_rid
is
not
None
:
if
req
.
session_rid
is
not
None
:
while
len
(
self
.
reqs
)
>
0
:
while
len
(
self
.
reqs
)
>
0
:
if
self
.
reqs
[
-
1
].
rid
==
req
.
session_rid
:
if
self
.
reqs
[
-
1
].
rid
==
req
.
session_rid
:
break
break
self
.
reqs
=
self
.
reqs
[:
-
1
]
self
.
reqs
=
self
.
reqs
[:
-
1
]
else
:
self
.
reqs
=
[]
if
len
(
self
.
reqs
)
>
0
:
if
len
(
self
.
reqs
)
>
0
:
input_ids
=
(
input_ids
=
(
self
.
reqs
[
-
1
].
origin_input_ids
self
.
reqs
[
-
1
].
origin_input_ids
...
@@ -58,4 +58,4 @@ class Session:
...
@@ -58,4 +58,4 @@ class Session:
)
)
else
:
else
:
self
.
reqs
.
append
(
new_req
)
self
.
reqs
.
append
(
new_req
)
return
new_req
,
self
.
session_id
return
new_req
python/sglang/srt/managers/tokenizer_manager.py
View file @
e1e595d7
...
@@ -216,8 +216,8 @@ class TokenizerManager:
...
@@ -216,8 +216,8 @@ class TokenizerManager:
return_logprob
=
obj
.
return_logprob
return_logprob
=
obj
.
return_logprob
logprob_start_len
=
obj
.
logprob_start_len
logprob_start_len
=
obj
.
logprob_start_len
top_logprobs_num
=
obj
.
top_logprobs_num
top_logprobs_num
=
obj
.
top_logprobs_num
session_id
=
obj
.
session
_id
session_id
=
obj
.
session
[
0
]
if
obj
.
session
else
None
session_rid
=
obj
.
session
_rid
session_rid
=
obj
.
session
[
1
]
if
obj
.
session
else
None
if
len
(
input_ids
)
>=
self
.
context_len
:
if
len
(
input_ids
)
>=
self
.
context_len
:
raise
ValueError
(
raise
ValueError
(
...
@@ -570,13 +570,11 @@ class TokenizerManager:
...
@@ -570,13 +570,11 @@ class TokenizerManager:
out_dict
=
{
out_dict
=
{
"text"
:
recv_obj
.
output_strs
[
i
],
"text"
:
recv_obj
.
output_strs
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
"session_id"
:
recv_obj
.
session_ids
[
i
],
}
}
elif
isinstance
(
recv_obj
,
BatchTokenIDOut
):
elif
isinstance
(
recv_obj
,
BatchTokenIDOut
):
out_dict
=
{
out_dict
=
{
"token_ids"
:
recv_obj
.
output_ids
[
i
],
"token_ids"
:
recv_obj
.
output_ids
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
"meta_info"
:
recv_obj
.
meta_info
[
i
],
"session_id"
:
recv_obj
.
session_ids
[
i
],
}
}
else
:
else
:
assert
isinstance
(
recv_obj
,
BatchEmbeddingOut
)
assert
isinstance
(
recv_obj
,
BatchEmbeddingOut
)
...
...
scripts/playground/test_session_id.py
deleted
100644 → 0
View file @
5ada33ff
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# FIXME: Make it a CI test
import
requests
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
# Manual playground walk-through of the session API against a local server.
# Each step prints the raw server reply followed by a newline so the run can
# be inspected by eye; nothing is asserted here.
url = "http://localhost:30000"


def _post_generate(text, session_id, max_new_tokens, session_rid=None):
    """POST one ``/generate`` request, print the decoded reply, and return it.

    ``session_rid`` is only placed in the payload when it is given, so the
    request body matches a hand-written dict with or without that key.
    """
    payload = {"text": text, "session_id": session_id}
    if session_rid is not None:
        payload["session_rid"] = session_rid
    payload["sampling_params"] = {
        "temperature": 0,
        "max_new_tokens": max_new_tokens,
    }
    reply = requests.post(url + "/generate", json=payload).json()
    print(reply, "\n")
    return reply


# Open a session
session_id = requests.post(
    url + "/open_session",
    json={"capacity_of_str_len": 1000},
).json()
print("session_id", session_id, "\n")

# Prefill only
reply = _post_generate("chunk 1", session_id, 0)
session_id = reply["session_id"]

# Generate
reply = _post_generate("Chunk 2", session_id, 16)
session_id = reply["session_id"]
rid = reply["meta_info"]["id"]

# Generate
reply = _post_generate("Chunk 3", session_id, 2)
session_id = reply["session_id"]
rid_to_del = reply["meta_info"]["id"]

# Interrupt and re-generate
reply = _post_generate("Chunk 4", session_id, 16, session_rid=rid)
session_id = reply["session_id"]

# Query a session based on a deleted request, should see finish reason abort
_post_generate("Chunk 4", session_id, 16, session_rid=rid_to_del)

# Close session
ret = requests.post(
    url + "/close_session",
    json={"session_id": session_id},
)
print(ret, "\n")

# Query a deleted session, should see finish reason abort
_post_generate("chunk 1", session_id, 0)
test/srt/run_suite.py
View file @
e1e595d7
...
@@ -34,6 +34,7 @@ suites = {
...
@@ -34,6 +34,7 @@ suites = {
"test_triton_attention_backend.py"
,
"test_triton_attention_backend.py"
,
"test_update_weights.py"
,
"test_update_weights.py"
,
"test_vision_openai_server.py"
,
"test_vision_openai_server.py"
,
"test_session_control.py"
,
],
],
"sampling/penaltylib"
:
glob
.
glob
(
"sampling/penaltylib"
:
glob
.
glob
(
"sampling/penaltylib/**/test_*.py"
,
recursive
=
True
"sampling/penaltylib/**/test_*.py"
,
recursive
=
True
...
...
test/srt/test_session_control.py
0 → 100644
View file @
e1e595d7
"""
Usage:
python3 -m unittest test_session_control.TestSessionControl.test_session_control
python3 -m unittest test_session_control.TestSessionControl.test_session_control_vlm
"""
import
unittest
import
requests
from
sglang.srt.hf_transformers_utils
import
get_tokenizer
from
sglang.srt.utils
import
kill_child_process
from
sglang.test.test_utils
import
(
DEFAULT_SMALL_MODEL_NAME_FOR_TEST
,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
,
DEFAULT_URL_FOR_TEST
,
popen_launch_server
,
)
class TestSessionControl(unittest.TestCase):
    """End-to-end test of the ``session`` field of the ``/generate`` endpoint.

    One server process is launched for the whole class. The test feeds a
    prompt chunk by chunk through a session and checks that the generated
    texts equal those obtained by resending the accumulated token ids without
    any session.
    """

    @classmethod
    def setUpClass(cls):
        """Launch the server under test once for all tests in this class."""
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )

    @classmethod
    def tearDownClass(cls):
        """Kill the launched server process (and the launcher itself)."""
        kill_child_process(cls.process.pid, include_self=True)

    def _generate(self, input_ids, max_new_tokens, session=None):
        """POST one greedy ``/generate`` request and return the decoded JSON.

        Args:
            input_ids: Token ids sent as ``"input_ids"``.
            max_new_tokens: Value of ``sampling_params["max_new_tokens"]``;
                ``0`` makes the request prefill-only.
            session: Optional ``[session_id, rid]`` pair sent as
                ``"session"``. The key is left out entirely when ``None`` so
                that session-free requests are unchanged.
        """
        payload = {"input_ids": input_ids}
        if session is not None:
            payload["session"] = session
        payload["sampling_params"] = {
            "temperature": 0,
            "max_new_tokens": max_new_tokens,
        }
        return requests.post(self.base_url + "/generate", json=payload).json()

    def test_session_control(self):
        """Session-based chunked generation must match plain generation."""
        chunks = [
            "Let me tell you something about France.",
            "The capital of France is",
            "A brief history about that city is",
            "To plan a travel, the budget is",
        ]
        tokenizer = get_tokenizer(self.model)
        chunks_ids = [tokenizer.encode(x) for x in chunks]

        # 1. using session control
        session_id = requests.post(
            self.base_url + "/open_session",
            json={"capacity_of_str_len": 1000},
        ).json()
        rid = None

        first_rid = None
        outputs_from_session = []
        for i, chunk_ids in enumerate(chunks_ids):
            # prefill only for the first chunk
            response = self._generate(
                chunk_ids,
                max_new_tokens=16 if i > 0 else 0,
                session=[session_id, rid],
            )
            # Each request continues from the request issued just before it.
            rid = response["meta_info"]["id"]
            if i == 0:
                first_rid = rid
            if i > 0:
                outputs_from_session.append(response["text"])

        # backtrack to the first request and regenerate
        response = self._generate(
            chunks_ids[-1], max_new_tokens=16, session=[session_id, first_rid]
        )
        outputs_from_session.append(response["text"])

        # query with a non-existing rid (the last one should have disappeared
        # because of the backtrack), should see abort
        response = self._generate(
            chunks_ids[-1], max_new_tokens=16, session=[session_id, rid]
        )
        # unittest assertions are used instead of bare ``assert`` so the
        # checks survive ``python -O`` and report both values on failure.
        self.assertEqual(response["meta_info"]["finish_reason"]["type"], "abort")

        ret = requests.post(
            self.base_url + "/close_session",
            json={"session_id": session_id},
        )
        self.assertEqual(ret.status_code, 200)

        # send a request to a closed session, should see abort
        response = self._generate(
            chunks_ids[-1], max_new_tokens=16, session=[session_id, first_rid]
        )
        self.assertEqual(response["meta_info"]["finish_reason"]["type"], "abort")

        # 2. not use session control
        input_ids_first_req = None
        input_ids = []
        outputs_normal = []
        for i, chunk_ids in enumerate(chunks_ids):
            input_ids += chunk_ids
            # prefill only for the first chunk
            response = self._generate(
                input_ids, max_new_tokens=16 if i > 0 else 0
            )
            if i > 0:
                # drop the bos token
                input_ids += tokenizer.encode(response["text"])[1:]
                outputs_normal.append(response["text"])
            if i == 0:
                # Snapshot of the first request's ids, reused below to mirror
                # the backtrack-to-first-request step of the session run.
                input_ids_first_req = input_ids.copy()

        input_ids_first_req += chunks_ids[-1]
        response = self._generate(input_ids_first_req, max_new_tokens=16)
        outputs_normal.append(response["text"])

        print("outputs from chunked queries with session control:")
        print(outputs_from_session)
        print("outputs from normal queries:")
        print(outputs_normal)
        self.assertEqual(outputs_from_session, outputs_normal)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment