sglang / Commits / 71b54eea

Commit 71b54eea (unverified), authored Jan 30, 2024 by Cody Yu, committed via GitHub on Jan 30, 2024. Parent: 74b3bfaa.

Add cache metrics (#119)

Showing 5 changed files with 87 additions and 27 deletions (+87 −27):
python/sglang/srt/constrained/base_cache.py  (+50 −0)
python/sglang/srt/constrained/fast_forward.py  (+5 −7)
python/sglang/srt/constrained/fsm_cache.py  (+5 −12)
python/sglang/srt/managers/router/infer_batch.py  (+5 −1)
python/sglang/srt/managers/router/model_rpc.py  (+22 −7)
python/sglang/srt/constrained/base_cache.py (new file, 0 → 100644)
"""Base cache class."""
import
time
class
BaseCache
:
def
__init__
(
self
,
enable
=
True
):
self
.
enable
=
enable
self
.
reset
()
def
reset
(
self
):
self
.
cache
=
{}
self
.
metrics
=
{
"total"
:
0
,
"hit"
:
0
,
"avg_init_time"
:
0
}
def
query
(
self
,
key
):
def
_init_with_timer
(
key
):
start
=
time
.
monotonic
()
val
=
self
.
init_value
(
key
)
init_time
=
time
.
monotonic
()
-
start
curr_total
=
self
.
metrics
[
"total"
]
new_total
=
curr_total
+
1
# Update average init time without old_avg * old_total to avoid overflow.
self
.
metrics
[
"avg_init_time"
]
=
(
init_time
/
new_total
)
+
(
curr_total
/
new_total
)
*
self
.
metrics
[
"avg_init_time"
]
self
.
metrics
[
"total"
]
+=
1
return
val
if
key
in
self
.
cache
:
self
.
metrics
[
"hit"
]
+=
1
val
=
self
.
cache
[
key
]
else
:
# Cache miss or disabled.
val
=
_init_with_timer
(
key
)
if
self
.
enable
:
self
.
cache
[
key
]
=
val
return
val
def
init_value
(
self
,
key
):
raise
NotImplementedError
def
get_cache_hit_rate
(
self
):
if
self
.
metrics
[
"total"
]
==
0
:
return
0
return
self
.
metrics
[
"hit"
]
/
self
.
metrics
[
"total"
]
def
get_avg_init_time
(
self
):
return
self
.
metrics
[
"avg_init_time"
]
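
The `avg_init_time` bookkeeping above is the incremental form of the arithmetic mean, new_avg = x/n + ((n-1)/n) * old_avg, so the old average is never multiplied by the old count (the overflow the inline comment mentions). A quick standalone check of that identity, with made-up timings rather than anything from this commit:

# Incremental-mean identity used by _init_with_timer, checked against the
# plain mean. The timing samples are made-up values for illustration.
timings = [0.8, 1.2, 0.5, 2.0]

avg = 0.0
for n, x in enumerate(timings, start=1):
    # new_avg = x / n + ((n - 1) / n) * old_avg -- never forms old_avg * (n - 1).
    avg = x / n + ((n - 1) / n) * avg

assert abs(avg - sum(timings) / len(timings)) < 1e-12
print(avg)  # 1.125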
python/sglang/srt/constrained/fast_forward.py
 import interegular

+from sglang.srt.constrained.base_cache import BaseCache
 from sglang.srt.constrained.disk_cache import disk_cache
 from sglang.srt.constrained.regex import FSMInfo, make_deterministic_fsm
...
@@ -56,15 +57,12 @@ class FastForwardMap:
         return fast_forward_str, next_state


-class FastForwardCache:
+class FastForwardCache(BaseCache):
     def __init__(self):
-        self.cache = {}
+        super().__init__()

-    def init_fast_forward_map(self, regex_string):
-        if regex_string not in self.cache:
-            fast_forward_map = FastForwardMap(regex_string)
-            self.cache[regex_string] = fast_forward_map
-        return self.cache[regex_string]
+    def init_value(self, regex):
+        return FastForwardMap(regex)


 def test_main():
...
python/sglang/srt/constrained/fsm_cache.py
+from sglang.srt.constrained.base_cache import BaseCache
 from sglang.srt.constrained.fsm import RegexFSM
 from sglang.srt.constrained.tokenizer import TransformerTokenizer

-_enable_memory_cache = True


-class FSMCache:
-    def __init__(self, tokenizer_path, tokenizer_args_dict):
-        self.cache = {}
+class FSMCache(BaseCache):
+    def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
+        super().__init__(enable=enable)
         self.outlines_tokenizer = TransformerTokenizer(
             tokenizer_path, **tokenizer_args_dict
         )

-    def init_fsm(self, regex):
-        if _enable_memory_cache:
-            if regex not in self.cache:
-                fsm = RegexFSM(regex, self.outlines_tokenizer)
-                self.cache[regex] = fsm
-            return self.cache[regex]
+    def init_value(self, regex):
+        return RegexFSM(regex, self.outlines_tokenizer)
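
With both constrained caches now on the shared base class, a subclass only implements `init_value`; callers go through `query`, and hit rate and average init time come from the getters. A rough usage sketch with a toy subclass (the `SquareCache` below is invented for illustration, not part of this commit):

from sglang.srt.constrained.base_cache import BaseCache


class SquareCache(BaseCache):
    # Hypothetical subclass: init_value is the only hook BaseCache requires.
    def init_value(self, key):
        return key * key


cache = SquareCache()
cache.query(3)  # miss: init_value runs and its latency feeds avg_init_time
cache.query(3)  # hit: served from the in-memory dict
cache.query(4)  # miss

print(cache.get_cache_hit_rate())  # hits / total, per BaseCache.metrics
print(cache.get_avg_init_time())   # running average of init_value latency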
python/sglang/srt/managers/router/infer_batch.py
...
@@ -60,7 +60,11 @@ class Req:
     def tokenize_fast_forward(self, fast_forward_str, next_state):
         old_output_str = self.tokenizer.decode(self.output_ids)
-        if self.tokenizer.convert_ids_to_tokens(self.output_ids[0]).startswith("▁"):
+        # FIXME: This logic does not really solve the problem of determining whether
+        # there should be a leading space.
+        first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
+        first_token = first_token.decode() if isinstance(first_token, bytes) else first_token
+        if first_token.startswith("▁"):
             old_output_str = " " + old_output_str
         new_input_string = (
             self.input_text
...
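
The FIXME block above guards against SentencePiece's leading-space marker: `decode()` drops the space that `▁` encodes on the first piece, so the old output string is patched back up before re-tokenizing. A self-contained sketch of the same heuristic with a stubbed tokenizer (the stub is hypothetical, not sglang code):

class StubSPTokenizer:
    # SentencePiece-style stub: "▁" on a piece marks a preceding space.
    def convert_ids_to_tokens(self, token_id):
        return {0: "▁Hello", 1: "▁world"}[token_id]

    def decode(self, ids):
        return "Hello world"  # the leading space encoded by "▁" is dropped


tokenizer = StubSPTokenizer()
output_ids = [0, 1]
old_output_str = tokenizer.decode(output_ids)

first_token = tokenizer.convert_ids_to_tokens(output_ids[0])
first_token = first_token.decode() if isinstance(first_token, bytes) else first_token
if first_token.startswith("▁"):
    old_output_str = " " + old_output_str  # restore the dropped leading space

print(repr(old_output_str))  # ' Hello world'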
python/sglang/srt/managers/router/model_rpc.py
...
@@ -4,8 +4,7 @@ import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
 from enum import Enum, auto
-from typing import Dict, List, Optional, Tuple, Union
+from typing import List

 import numpy as np
 import rpyc
...
@@ -99,6 +98,7 @@ class ModelRpcServer(rpyc.Service):
         # Init cache
         self.tree_cache = RadixCache(disable="no-cache" in self.model_mode)
+        self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
             self.max_num_running_seq,
...
@@ -136,6 +136,8 @@ class ModelRpcServer(rpyc.Service):
             self.running_batch is None or len(self.running_batch.reqs) == 0
         ):
             self.tree_cache.reset()
+            self.tree_cache_metrics = {"total": 0, "hit": 0}
+            self.regex_fsm_cache.reset()
             self.req_to_token_pool.clear()
             self.token_to_kv_pool.clear()
             torch.cuda.empty_cache()
...
@@ -248,9 +250,9 @@ class ModelRpcServer(rpyc.Service):
         # Init regex fsm
         if req.sampling_params.regex is not None:
-            req.regex_fsm = self.regex_fsm_cache.init_fsm(req.sampling_params.regex)
+            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
             if not self.no_regex_fast_forward:
-                req.fast_forward_map = self.fast_forward_cache.init_fast_forward_map(
+                req.fast_forward_map = self.fast_forward_cache.query(
                     req.sampling_params.regex
                 )
...
@@ -285,7 +287,6 @@ class ModelRpcServer(rpyc.Service):
         can_run_list = []
         new_batch_total_tokens = 0
         new_batch_input_tokens = 0
-        new_batch_prefix_tokens = 0

         available_size = (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
...
@@ -343,12 +344,26 @@ class ModelRpcServer(rpyc.Service):
             return None

         if self.tp_rank == 0:
+            running_req = (
+                0 if self.running_batch is None else len(self.running_batch.reqs)
+            )
+            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
+            self.tree_cache_metrics["total"] += (
+                hit_tokens + new_batch_input_tokens
+            ) / 10**9
+            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
+            tree_cache_hit_rate = (
+                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+            )
             logger.info(
                 f"new fill batch. #seq: {len(can_run_list)}. "
-                f"#cached_token: {sum(len(x.prefix_indices) for x in can_run_list)}. "
+                f"#cached_token: {hit_tokens}. "
                 f"#new_token: {new_batch_input_tokens}. "
                 f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
-                f"#running_req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+                f"#running_req: {running_req}. "
+                f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
             )
+            logger.debug(
+                f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
+                f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
+                f"ff_cache_hit_rate: {100.0 * self.fast_forward_cache.get_cache_hit_rate():.2f}%. "
+                f"ff_cache_avg_init_time: {self.fast_forward_cache.get_avg_init_time():.2f}s. "
+            )

         new_batch = Batch.init_new(
...
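
One detail worth noting in the last hunk: both `tree_cache_metrics` counters are divided by the same constant, 10**9, so the factor cancels in `tree_cache_hit_rate` while keeping the accumulators small on long-running servers. A quick check with made-up token counts:

# Made-up per-batch token counts, for illustration only.
hit_tokens, new_batch_input_tokens = 3_000, 5_000

raw_rate = hit_tokens / (hit_tokens + new_batch_input_tokens)
scaled_hit = hit_tokens / 10**9
scaled_total = (hit_tokens + new_batch_input_tokens) / 10**9

# The common 10**9 factor cancels, so the hit rate is unchanged.
assert abs(raw_rate - scaled_hit / scaled_total) < 1e-12
print(scaled_hit / scaled_total)  # 0.375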