Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
19818b9c
Unverified
Commit
19818b9c
authored
Apr 26, 2024
by
Liangsheng Yin
Committed by
GitHub
Apr 26, 2024
Browse files
Minor: style improvement of radix_cache and memory_pool (#395)
parent
9216b106
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
26 additions
and
37 deletions
+26
-37
python/sglang/srt/managers/router/infer_batch.py
python/sglang/srt/managers/router/infer_batch.py
+6
-7
python/sglang/srt/managers/router/model_rpc.py
python/sglang/srt/managers/router/model_rpc.py
+2
-2
python/sglang/srt/managers/router/radix_cache.py
python/sglang/srt/managers/router/radix_cache.py
+11
-12
python/sglang/srt/memory_pool.py
python/sglang/srt/memory_pool.py
+5
-14
python/sglang/srt/server.py
python/sglang/srt/server.py
+1
-1
test/srt/model/test_llama_low_api.py
test/srt/model/test_llama_low_api.py
+1
-1
No files found.
python/sglang/srt/managers/router/infer_batch.py
View file @
19818b9c
...
...
@@ -236,9 +236,8 @@ class Batch:
extend_num_tokens
=
seq_lens
.
sum
()
-
prefix_lens
.
sum
()
out_cache_loc
=
self
.
token_to_kv_pool
.
alloc
(
extend_num_tokens
)
if
out_cache_loc
is
None
:
if
not
self
.
tree_cache
.
disable
:
self
.
tree_cache
.
evict
(
extend_num_tokens
,
self
.
token_to_kv_pool
.
free
)
out_cache_loc
=
self
.
token_to_kv_pool
.
alloc
(
extend_num_tokens
)
self
.
tree_cache
.
evict
(
extend_num_tokens
,
self
.
token_to_kv_pool
.
dec_refs
)
out_cache_loc
=
self
.
token_to_kv_pool
.
alloc
(
extend_num_tokens
)
if
out_cache_loc
is
None
:
print
(
"Prefill out of memory. This should never happen."
)
...
...
@@ -307,8 +306,8 @@ class Batch:
if
self
.
token_to_kv_pool
.
available_size
()
>=
bs
:
return
True
if
not
self
.
tree_cache
.
disable
:
self
.
tree_cache
.
evict
(
bs
,
self
.
token_to_kv_pool
.
free
)
self
.
tree_cache
.
evict
(
bs
,
self
.
token_to_kv_pool
.
dec_refs
)
if
self
.
token_to_kv_pool
.
available_size
()
>=
bs
:
return
True
...
...
@@ -341,7 +340,7 @@ class Batch:
token_indices
=
self
.
req_to_token_pool
.
req_to_token
[
req_pool_indices_np
[
idx
]
][:
seq_lens_np
[
idx
]]
self
.
token_to_kv_pool
.
free
(
token_indices
)
self
.
token_to_kv_pool
.
dec_refs
(
token_indices
)
self
.
filter_batch
(
sorted_indices
)
...
...
@@ -372,7 +371,7 @@ class Batch:
prefix_len
=
self
.
tree_cache
.
insert
(
token_ids_in_memory
,
indices
.
clone
()
)
self
.
token_to_kv_pool
.
free
(
indices
[:
prefix_len
])
self
.
token_to_kv_pool
.
dec_refs
(
indices
[:
prefix_len
])
self
.
req_to_token_pool
.
free
(
req_pool_idx
)
self
.
tree_cache
.
dec_ref_counter
(
req
.
last_node
)
...
...
python/sglang/srt/managers/router/model_rpc.py
View file @
19818b9c
...
...
@@ -113,7 +113,7 @@ class ModelRpcServer:
logger
.
info
(
server_args
.
get_optional_modes_logging
())
# Init cache
self
.
tree_cache
=
RadixCache
(
server_args
.
disable_radix_cache
)
self
.
tree_cache
=
RadixCache
(
disable
=
server_args
.
disable_radix_cache
)
self
.
tree_cache_metrics
=
{
"total"
:
0
,
"hit"
:
0
}
self
.
scheduler
=
Scheduler
(
self
.
schedule_heuristic
,
...
...
@@ -628,7 +628,7 @@ class ModelRpcServer:
token_ids
[:
seq_len
],
indices
.
clone
()
)
self
.
token_to_kv_pool
.
free
(
indices
[:
prefix_len
])
self
.
token_to_kv_pool
.
dec_refs
(
indices
[:
prefix_len
])
self
.
req_to_token_pool
.
free
(
req_pool_idx
)
self
.
tree_cache
.
dec_ref_counter
(
req
.
last_node
)
...
...
python/sglang/srt/managers/router/radix_cache.py
View file @
19818b9c
import
heapq
import
time
from
collections
import
defaultdict
from
dataclasses
import
dataclass
from
typing
import
Tuple
import
torch
...
...
@@ -16,23 +14,23 @@ class TreeNode:
self
.
ref_counter
=
0
self
.
last_access_time
=
time
.
time
()
def
__lt__
(
self
,
other
):
def
__lt__
(
self
,
other
:
"TreeNode"
):
return
self
.
last_access_time
<
other
.
last_access_time
def
match
(
key
,
seq
):
def
_key_
match
(
key
0
,
key1
):
i
=
0
for
k
,
w
in
zip
(
key
,
seq
):
if
k
!=
w
:
for
k
0
,
k1
in
zip
(
key
0
,
key1
):
if
k
0
!=
k1
:
break
i
+=
1
return
i
class
RadixCache
:
def
__init__
(
self
,
disable
=
False
):
self
.
reset
()
def
__init__
(
self
,
disable
:
bool
=
False
):
self
.
disable
=
disable
self
.
reset
()
##### Public API #####
...
...
@@ -71,7 +69,7 @@ class RadixCache:
def
evict
(
self
,
num_tokens
,
evict_callback
):
if
self
.
disable
:
r
aise
RuntimeError
()
r
eturn
leaves
=
self
.
_collect_leaves
()
heapq
.
heapify
(
leaves
)
...
...
@@ -115,6 +113,7 @@ class RadixCache:
return
self
.
evictable_size_
##### Internal Helper Functions #####
def
_match_prefix_helper
(
self
,
node
,
key
,
value
,
last_node
):
node
.
last_access_time
=
time
.
time
()
if
len
(
key
)
==
0
:
...
...
@@ -122,7 +121,7 @@ class RadixCache:
if
key
[
0
]
in
node
.
children
.
keys
():
child
=
node
.
children
[
key
[
0
]]
prefix_len
=
match
(
child
.
key
,
key
)
prefix_len
=
_key_
match
(
child
.
key
,
key
)
if
prefix_len
<
len
(
child
.
key
):
new_node
=
self
.
_split_node
(
child
.
key
,
child
,
prefix_len
)
value
.
append
(
new_node
.
value
)
...
...
@@ -153,7 +152,7 @@ class RadixCache:
if
key
[
0
]
in
node
.
children
.
keys
():
child
=
node
.
children
[
key
[
0
]]
prefix_len
=
match
(
child
.
key
,
key
)
prefix_len
=
_key_
match
(
child
.
key
,
key
)
if
prefix_len
==
len
(
child
.
key
):
if
prefix_len
==
len
(
key
):
...
...
@@ -212,7 +211,7 @@ class RadixCache:
if
__name__
==
"__main__"
:
tree
=
RadixCache
(
disable
=
False
)
tree
=
RadixCache
()
tree
.
insert
(
"Hello"
)
tree
.
insert
(
"Hello"
)
...
...
python/sglang/srt/memory_pool.py
View file @
19818b9c
...
...
@@ -31,9 +31,6 @@ class ReqToTokenPool:
self
.
can_use_mem_size
+=
free_index
.
shape
[
0
]
self
.
mem_state
[
free_index
]
=
1
# if self.can_use_mem_size == len(self.mem_state):
# print(f"ReqToTokenPool: freed all. size = {self.can_use_mem_size}.")
def
clear
(
self
):
self
.
mem_state
.
fill_
(
1
)
self
.
can_use_mem_size
=
len
(
self
.
mem_state
)
...
...
@@ -42,7 +39,7 @@ class ReqToTokenPool:
class
TokenToKVPool
:
def
__init__
(
self
,
size
,
dtype
,
head_num
,
head_dim
,
layer_num
):
self
.
mem_state
=
torch
.
zeros
((
size
,),
dtype
=
torch
.
int16
,
device
=
"cuda"
)
self
.
alloc
_ct
=
0
self
.
total_ref
_ct
=
0
# [size, key/value, head_num, head_dim] for each layer
self
.
kv_data
=
[
...
...
@@ -83,9 +80,6 @@ class TokenToKVPool:
self
.
add_refs
(
select_index
)
return
select_index
.
to
(
torch
.
int32
),
start_loc
,
start_loc
+
need_size
def
free
(
self
,
free_index
):
return
self
.
decrease_refs
(
free_index
)
def
used_size
(
self
):
return
len
(
torch
.
nonzero
(
self
.
mem_state
).
squeeze
(
1
))
...
...
@@ -93,20 +87,17 @@ class TokenToKVPool:
return
torch
.
sum
(
self
.
mem_state
==
0
).
item
()
def
add_refs
(
self
,
token_index
:
torch
.
Tensor
):
self
.
alloc
_ct
+=
len
(
token_index
)
self
.
total_ref
_ct
+=
len
(
token_index
)
self
.
mem_state
[
token_index
]
+=
1
def
dec
rease
_refs
(
self
,
token_index
:
torch
.
Tensor
):
self
.
alloc
_ct
-=
len
(
token_index
)
def
dec_refs
(
self
,
token_index
:
torch
.
Tensor
):
self
.
total_ref
_ct
-=
len
(
token_index
)
self
.
mem_state
[
token_index
]
-=
1
num_freed
=
torch
.
sum
(
self
.
mem_state
[
token_index
]
==
0
)
# if self.alloc_ct == 0:
# print(f"TokenToKVPool: freed all. size = {len(self.mem_state)}.")
return
num_freed
def
clear
(
self
):
self
.
mem_state
.
fill_
(
0
)
self
.
alloc
_ct
=
0
self
.
total_ref
_ct
=
0
python/sglang/srt/server.py
View file @
19818b9c
...
...
@@ -500,7 +500,7 @@ async def v1_chat_completions(raw_request: Request):
return
response
def
launch_server
(
server_args
,
pipe_finish_writer
):
def
launch_server
(
server_args
:
ServerArgs
,
pipe_finish_writer
):
global
tokenizer_manager
global
chat_template_name
...
...
test/srt/model/test_llama_low_api.py
View file @
19818b9c
...
...
@@ -105,7 +105,7 @@ def test_generate_worker(
for
i
in
range
(
batch_size
):
req_idx
=
req_pool_indices
[
i
].
item
()
model
.
token_to_kv_pool
.
free
(
model
.
token_to_kv_pool
.
dec_refs
(
model
.
req_to_token_pool
.
req_to_token
[
req_idx
,
:
seq_lens
[
i
]]
)
model
.
req_to_token_pool
.
free
(
req_pool_indices
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment