Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
08ab2a16
Unverified
Commit
08ab2a16
authored
Jan 15, 2024
by
Liangsheng Yin
Committed by
GitHub
Jan 15, 2024
Browse files
Json Decode && Mutl-Turns (#4)
parent
f652494d
Changes
27
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
51 additions
and
8 deletions
+51
-8
python/sglang/srt/managers/router/manager.py
python/sglang/srt/managers/router/manager.py
+10
-2
python/sglang/srt/managers/router/model_rpc.py
python/sglang/srt/managers/router/model_rpc.py
+15
-2
python/sglang/srt/managers/router/scheduler.py
python/sglang/srt/managers/router/scheduler.py
+1
-1
python/sglang/srt/sampling_params.py
python/sglang/srt/sampling_params.py
+2
-2
python/sglang/srt/server_args.py
python/sglang/srt/server_args.py
+2
-0
python/sglang/test/test_utils.py
python/sglang/test/test_utils.py
+20
-0
python/sglang/utils.py
python/sglang/utils.py
+1
-1
No files found.
python/sglang/srt/managers/router/manager.py
View file @
08ab2a16
import
asyncio
import
asyncio
import
logging
import
logging
from
typing
import
List
,
Tuple
import
uvloop
import
uvloop
import
zmq
import
zmq
...
@@ -8,6 +7,7 @@ import zmq.asyncio
...
@@ -8,6 +7,7 @@ import zmq.asyncio
from
sglang.srt.managers.router.model_rpc
import
ModelRpcClient
from
sglang.srt.managers.router.model_rpc
import
ModelRpcClient
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.utils
import
get_exception_traceback
from
sglang.srt.utils
import
get_exception_traceback
from
sglang.srt.backend_config
import
GLOBAL_BACKEND_CONFIG
asyncio
.
set_event_loop_policy
(
uvloop
.
EventLoopPolicy
())
asyncio
.
set_event_loop_policy
(
uvloop
.
EventLoopPolicy
())
...
@@ -28,6 +28,9 @@ class RouterManager:
...
@@ -28,6 +28,9 @@ class RouterManager:
self
.
model_client
=
model_client
self
.
model_client
=
model_client
self
.
recv_reqs
=
[]
self
.
recv_reqs
=
[]
# Init Some Configs
self
.
extend_dependency_time
=
GLOBAL_BACKEND_CONFIG
.
extend_dependency_time
async
def
loop_for_forward
(
self
):
async
def
loop_for_forward
(
self
):
while
True
:
while
True
:
next_step_input
=
list
(
self
.
recv_reqs
)
next_step_input
=
list
(
self
.
recv_reqs
)
...
@@ -37,7 +40,12 @@ class RouterManager:
...
@@ -37,7 +40,12 @@ class RouterManager:
for
obj
in
out_pyobjs
:
for
obj
in
out_pyobjs
:
self
.
send_to_detokenizer
.
send_pyobj
(
obj
)
self
.
send_to_detokenizer
.
send_pyobj
(
obj
)
# await for a while to accept input requests
# async sleep for recving the subsequent request, and avoiding cache miss
if
len
(
out_pyobjs
)
!=
0
:
has_finished
=
any
([
obj
.
finished
for
obj
in
out_pyobjs
])
if
has_finished
:
await
asyncio
.
sleep
(
self
.
extend_dependency_time
)
await
asyncio
.
sleep
(
0.001
)
await
asyncio
.
sleep
(
0.001
)
async
def
loop_for_recv_requests
(
self
):
async
def
loop_for_recv_requests
(
self
):
...
...
python/sglang/srt/managers/router/model_rpc.py
View file @
08ab2a16
...
@@ -19,7 +19,6 @@ from sglang.srt.managers.router.model_runner import ModelRunner
...
@@ -19,7 +19,6 @@ from sglang.srt.managers.router.model_runner import ModelRunner
from
sglang.srt.managers.router.radix_cache
import
RadixCache
from
sglang.srt.managers.router.radix_cache
import
RadixCache
from
sglang.srt.managers.router.scheduler
import
Scheduler
from
sglang.srt.managers.router.scheduler
import
Scheduler
from
sglang.srt.model_config
import
ModelConfig
from
sglang.srt.model_config
import
ModelConfig
from
sglang.srt.sampling_params
import
SamplingParams
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.server_args
import
PortArgs
,
ServerArgs
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
get_exception_traceback
,
get_exception_traceback
,
...
@@ -158,6 +157,18 @@ class ModelRpcServer(rpyc.Service):
...
@@ -158,6 +157,18 @@ class ModelRpcServer(rpyc.Service):
if
self
.
running_batch
.
is_empty
():
if
self
.
running_batch
.
is_empty
():
self
.
running_batch
=
None
self
.
running_batch
=
None
break
break
else
:
# check the available size
available_size
=
(
self
.
token_to_kv_pool
.
available_size
()
+
self
.
tree_cache
.
evictable_size
()
)
if
available_size
!=
self
.
max_total_num_token
:
logger
.
warning
(
"Warning: "
f
"available_size=
{
available_size
}
, max_total_num_token=
{
self
.
max_total_num_token
}
\n
"
"KV cache pool leak detected!"
)
if
self
.
running_batch
is
not
None
and
self
.
tp_rank
==
0
:
if
self
.
running_batch
is
not
None
and
self
.
tp_rank
==
0
:
if
self
.
decode_forward_ct
>=
20
:
if
self
.
decode_forward_ct
>=
20
:
...
@@ -408,7 +419,9 @@ class ModelRpcServer(rpyc.Service):
...
@@ -408,7 +419,9 @@ class ModelRpcServer(rpyc.Service):
token_ids
=
tuple
(
req
.
input_ids
+
req
.
output_ids
)
token_ids
=
tuple
(
req
.
input_ids
+
req
.
output_ids
)
seq_len
=
len
(
token_ids
)
-
1
seq_len
=
len
(
token_ids
)
-
1
indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_idx
,
:
seq_len
]
indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_idx
,
:
seq_len
]
prefix_len
=
self
.
tree_cache
.
insert
(
token_ids
,
indices
.
clone
())
prefix_len
=
self
.
tree_cache
.
insert
(
token_ids
[:
seq_len
],
indices
.
clone
()
)
self
.
token_to_kv_pool
.
free
(
indices
[:
prefix_len
])
self
.
token_to_kv_pool
.
free
(
indices
[:
prefix_len
])
self
.
req_to_token_pool
.
free
(
req_pool_idx
)
self
.
req_to_token_pool
.
free
(
req_pool_idx
)
...
...
python/sglang/srt/managers/router/scheduler.py
View file @
08ab2a16
...
@@ -18,7 +18,7 @@ class Scheduler:
...
@@ -18,7 +18,7 @@ class Scheduler:
self
.
tree_cache
=
tree_cache
self
.
tree_cache
=
tree_cache
def
new_token_estimation_ratio
(
self
):
def
new_token_estimation_ratio
(
self
):
return
0.
4
if
self
.
schedule_heuristic
!=
"fcfs"
else
0.
5
return
0.
5
if
self
.
schedule_heuristic
!=
"fcfs"
else
0.
6
def
get_priority_queue
(
self
,
forward_queue
):
def
get_priority_queue
(
self
,
forward_queue
):
if
self
.
schedule_heuristic
==
"lpm"
:
if
self
.
schedule_heuristic
==
"lpm"
:
...
...
python/sglang/srt/sampling_params.py
View file @
08ab2a16
...
@@ -7,13 +7,13 @@ _SAMPLING_EPS = 1e-6
...
@@ -7,13 +7,13 @@ _SAMPLING_EPS = 1e-6
class
SamplingParams
:
class
SamplingParams
:
def
__init__
(
def
__init__
(
self
,
self
,
max_new_tokens
:
int
=
16
,
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
temperature
:
float
=
1.0
,
temperature
:
float
=
1.0
,
top_p
:
float
=
1.0
,
top_p
:
float
=
1.0
,
top_k
:
int
=
-
1
,
top_k
:
int
=
-
1
,
frequency_penalty
:
float
=
0.0
,
frequency_penalty
:
float
=
0.0
,
presence_penalty
:
float
=
0.0
,
presence_penalty
:
float
=
0.0
,
stop
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
,
max_new_tokens
:
int
=
16
,
ignore_eos
:
bool
=
False
,
ignore_eos
:
bool
=
False
,
skip_special_tokens
:
bool
=
True
,
skip_special_tokens
:
bool
=
True
,
dtype
:
Optional
[
str
]
=
None
,
dtype
:
Optional
[
str
]
=
None
,
...
...
python/sglang/srt/server_args.py
View file @
08ab2a16
...
@@ -24,6 +24,8 @@ class ServerArgs:
...
@@ -24,6 +24,8 @@ class ServerArgs:
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
tokenizer_path
is
None
:
if
self
.
tokenizer_path
is
None
:
self
.
tokenizer_path
=
self
.
model_path
self
.
tokenizer_path
=
self
.
model_path
if
self
.
tp_size
>
1
:
self
.
mem_fraction_static
=
0.8
@
staticmethod
@
staticmethod
def
add_cli_args
(
parser
:
argparse
.
ArgumentParser
):
def
add_cli_args
(
parser
:
argparse
.
ArgumentParser
):
...
...
python/sglang/test/test_utils.py
View file @
08ab2a16
...
@@ -38,6 +38,26 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
...
@@ -38,6 +38,26 @@ def call_generate_vllm(prompt, temperature, max_tokens, stop, url, n=1):
return
pred
return
pred
def
call_generate_outlines
(
prompt
,
temperature
,
max_tokens
,
url
,
stop
=
[],
regex
=
None
,
n
=
1
):
data
=
{
"prompt"
:
prompt
,
"temperature"
:
temperature
,
"max_tokens"
:
max_tokens
,
"stop"
:
stop
,
"regex"
:
regex
,
"n"
:
n
,
}
res
=
requests
.
post
(
url
,
json
=
data
)
assert
res
.
status_code
==
200
if
n
==
1
:
pred
=
res
.
json
()[
"text"
][
0
][
len
(
prompt
)
:]
else
:
pred
=
[
x
[
len
(
prompt
)
:]
for
x
in
res
.
json
()[
"text"
]]
return
pred
def
call_generate_srt_raw
(
prompt
,
temperature
,
max_tokens
,
stop
,
url
):
def
call_generate_srt_raw
(
prompt
,
temperature
,
max_tokens
,
stop
,
url
):
data
=
{
data
=
{
"text"
:
prompt
,
"text"
:
prompt
,
...
...
python/sglang/utils.py
View file @
08ab2a16
...
@@ -67,7 +67,7 @@ def dump_state_text(filename, states, mode="w"):
...
@@ -67,7 +67,7 @@ def dump_state_text(filename, states, mode="w"):
if
isinstance
(
s
,
str
):
if
isinstance
(
s
,
str
):
pass
pass
elif
isinstance
(
s
,
ProgramState
):
elif
isinstance
(
s
,
ProgramState
):
s
=
s
.
text
()
.
strip
()
s
=
s
.
text
()
else
:
else
:
s
=
str
(
s
)
s
=
str
(
s
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment