sglang / Commits / 71b54eea

Add cache metrics (#119)

Unverified commit 71b54eea, authored Jan 30, 2024 by Cody Yu, committed by GitHub on Jan 30, 2024. Parent: 74b3bfaa.

Changes: 5 changed files with 87 additions and 27 deletions (+87, -27).
python/sglang/srt/constrained/base_cache.py       +50   -0
python/sglang/srt/constrained/fast_forward.py      +5   -7
python/sglang/srt/constrained/fsm_cache.py         +5  -12
python/sglang/srt/managers/router/infer_batch.py   +5   -1
python/sglang/srt/managers/router/model_rpc.py    +22   -7
python/sglang/srt/constrained/base_cache.py (new file, mode 100644)
"""Base cache class."""
import
time
class
BaseCache
:
def
__init__
(
self
,
enable
=
True
):
self
.
enable
=
enable
self
.
reset
()
def
reset
(
self
):
self
.
cache
=
{}
self
.
metrics
=
{
"total"
:
0
,
"hit"
:
0
,
"avg_init_time"
:
0
}
def
query
(
self
,
key
):
def
_init_with_timer
(
key
):
start
=
time
.
monotonic
()
val
=
self
.
init_value
(
key
)
init_time
=
time
.
monotonic
()
-
start
curr_total
=
self
.
metrics
[
"total"
]
new_total
=
curr_total
+
1
# Update average init time without old_avg * old_total to avoid overflow.
self
.
metrics
[
"avg_init_time"
]
=
(
init_time
/
new_total
)
+
(
curr_total
/
new_total
)
*
self
.
metrics
[
"avg_init_time"
]
self
.
metrics
[
"total"
]
+=
1
return
val
if
key
in
self
.
cache
:
self
.
metrics
[
"hit"
]
+=
1
val
=
self
.
cache
[
key
]
else
:
# Cache miss or disabled.
val
=
_init_with_timer
(
key
)
if
self
.
enable
:
self
.
cache
[
key
]
=
val
return
val
def
init_value
(
self
,
key
):
raise
NotImplementedError
def
get_cache_hit_rate
(
self
):
if
self
.
metrics
[
"total"
]
==
0
:
return
0
return
self
.
metrics
[
"hit"
]
/
self
.
metrics
[
"total"
]
def
get_avg_init_time
(
self
):
return
self
.
metrics
[
"avg_init_time"
]
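The timer update is the standard incremental mean, avg_n = t_n / n + ((n - 1) / n) * avg_(n-1), where n is the new miss count and t_n the latest init time; this avoids multiplying the old average by the old total, per the comment in the code. A minimal sketch of how a subclass plugs in — SquareCache and its keys are hypothetical, standing in for an expensive construction such as building an FSM:

# Hypothetical subclass for illustration; only init_value must be overridden.
class SquareCache(BaseCache):
    def init_value(self, key):
        return key * key  # stands in for an expensive construction

cache = SquareCache()
cache.query(7)  # miss: runs init_value under the timer, updates the metrics
cache.query(7)  # hit: returns the memoized value and increments "hit"
print(cache.get_cache_hit_rate(), cache.get_avg_init_time())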
python/sglang/srt/constrained/fast_forward.py
 import interegular
+from sglang.srt.constrained.base_cache import BaseCache
 from sglang.srt.constrained.disk_cache import disk_cache
 from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm
...
@@ -56,15 +57,12 @@ class FastForwardMap:
         return fast_forward_str, next_state


-class FastForwardCache:
+class FastForwardCache(BaseCache):
     def __init__(self):
-        self.cache = {}
+        super().__init__()

-    def init_fast_forward_map(self, regex_string):
-        if regex_string not in self.cache:
-            fast_forward_map = FastForwardMap(regex_string)
-            self.cache[regex_string] = fast_forward_map
-        return self.cache[regex_string]
+    def init_value(self, regex):
+        return FastForwardMap(regex)


 def test_main():
...
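After this refactor, memoization and metrics live in BaseCache.query, and FastForwardCache only describes how to build a missing entry. A sketch of the new call path (the regex is a hypothetical example input):

cache = FastForwardCache()
ff_map = cache.query(r"\d+")          # miss: BaseCache calls init_value -> FastForwardMap(r"\d+")
assert cache.query(r"\d+") is ff_map  # hit: the same object comes back from the cache dict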
python/sglang/srt/constrained/fsm_cache.py
+from sglang.srt.constrained.base_cache import BaseCache
 from sglang.srt.constrained.fsm import RegexFSM
 from sglang.srt.constrained.tokenizer import TransformerTokenizer

-_enable_memory_cache = True
-
-
-class FSMCache:
-    def __init__(self, tokenizer_path, tokenizer_args_dict):
-        self.cache = {}
+class FSMCache(BaseCache):
+    def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
+        super().__init__(enable=enable)
         self.outlines_tokenizer = TransformerTokenizer(
             tokenizer_path, **tokenizer_args_dict
         )

-    def init_fsm(self, regex):
-        if _enable_memory_cache:
-            if regex not in self.cache:
-                fsm = RegexFSM(regex, self.outlines_tokenizer)
-                self.cache[regex] = fsm
-            return self.cache[regex]
+    def init_value(self, regex):
         return RegexFSM(regex, self.outlines_tokenizer)
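The module-level _enable_memory_cache flag becomes a per-instance enable argument forwarded to BaseCache, so a disabled cache rebuilds the FSM on every query while still recording timing metrics. A sketch of the two modes; the tokenizer path and regex are hypothetical placeholders for whatever TransformerTokenizer accepts:

# Hypothetical arguments, for illustration only.
fsm_cache = FSMCache("meta-llama/Llama-2-7b-hf", {}, enable=True)
fsm = fsm_cache.query(r"(yes|no)")   # built once, then memoized

no_cache = FSMCache("meta-llama/Llama-2-7b-hf", {}, enable=False)
fsm2 = no_cache.query(r"(yes|no)")   # rebuilt on every call; never stored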
python/sglang/srt/managers/router/infer_batch.py
...
@@ -60,7 +60,11 @@ class Req:
     def tokenize_fast_forward(self, fast_forward_str, next_state):
         old_output_str = self.tokenizer.decode(self.output_ids)
-        if self.tokenizer.convert_ids_to_tokens(self.output_ids[0]).startswith("▁"):
+        # FIXME: This logic does not really solve the problem of determining whether
+        # there should be a leading space.
+        first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
+        first_token = first_token.decode() if isinstance(first_token, bytes) else first_token
+        if first_token.startswith("▁"):
             old_output_str = " " + old_output_str
         new_input_string = (
             self.input_text
...
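The fix accounts for tokenizers whose convert_ids_to_tokens returns bytes rather than str; the SentencePiece marker "▁" signals that the token carries a leading space, which decode() drops. A self-contained sketch with hypothetical values:

old_output_str = "Hello world"   # what tokenizer.decode produced
first_token = "▁Hello".encode()  # some tokenizers hand back bytes
first_token = first_token.decode() if isinstance(first_token, bytes) else first_token
if first_token.startswith("▁"):  # SentencePiece leading-space marker
    old_output_str = " " + old_output_str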
python/sglang/srt/managers/router/model_rpc.py
...
@@ -4,8 +4,7 @@ import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from enum import Enum, auto
-from typing import List
+from typing import Dict, List, Optional, Tuple, Union

 import numpy as np
 import rpyc
...
@@ -99,6 +98,7 @@ class ModelRpcServer(rpyc.Service):
         # Init cache
         self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
+        self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
             self.max_num_running_seq,
...
@@ -136,6 +136,8 @@ class ModelRpcServer(rpyc.Service):
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
             self.tree_cache.reset()
+            self.tree_cache_metrics = {"total": 0, "hit": 0}
+            self.regex_fsm_cache.reset()
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
...
@@ -248,9 +250,9 @@ class ModelRpcServer(rpyc.Service):
             # Init regex fsm
             if req.sampling_params.regex is not None:
-                req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex)
+                req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
                 if not self.no_regex_fast_forward:
-                    req.fast_forward_map = self.fast_forward_cache.init_fast_forward_map(
+                    req.fast_forward_map = self.fast_forward_cache.query(
                         req.sampling_params.regex
                     )
...
@@ -285,7 +287,6 @@ class ModelRpcServer(rpyc.Service):
         can_run_list = []
         new_batch_total_tokens = 0
         new_batch_input_tokens = 0
-        new_batch_prefix_tokens = 0

         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
...
@@ -343,12 +344,26 @@ class ModelRpcServer(rpyc.Service):
             return None

         if self.tp_rank == 0:
+            running_req = 0 if self.running_batch is None else len(self.running_batch.reqs)
+            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
+            self.tree_cache_metrics["total"] += (
+                hit_tokens + new_batch_input_tokens
+            ) / 10**9
+            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
+            tree_cache_hit_rate = (
+                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+            )
             logger.info(
                 f"new fill batch. #seq: {len(can_run_list)}. "
-                f"#cached_token: {sum(len(x.prefix_indices) for x in can_run_list)}. "
+                f"#cached_token: {hit_tokens}. "
                 f"#new_token: {new_batch_input_tokens}. "
                 f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
-                f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+                f"#running_req: {running_req}. "
+                f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
             )
+            logger.debug(
+                f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
+                f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
+                f"ff_cache_hit_rate: {100.0 * self.fast_forward_cache.get_cache_hit_rate():.2f}%. "
+                f"ff_cache_avg_init_time: {self.fast_forward_cache.get_avg_init_time():.2f}s. "
+            )

         new_batch = Batch.init_new(
...
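Both tree-cache counters accumulate in units of 10**9 tokens, which keeps the running floats small on a long-lived server; since the hit rate is a ratio, the common scale factor cancels. A worked sketch with hypothetical token counts:

tree_cache_metrics = {"total": 0, "hit": 0}
hit_tokens, new_batch_input_tokens = 300, 700  # hypothetical batch

tree_cache_metrics["total"] += (hit_tokens + new_batch_input_tokens) / 10**9
tree_cache_metrics["hit"] += hit_tokens / 10**9

# The 1/10**9 factor cancels in the ratio: 300 / 1000 -> 30.00%.
tree_cache_hit_rate = tree_cache_metrics["hit"] / tree_cache_metrics["total"]
print(f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%")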