Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
bc12d403
Unverified
Commit
bc12d403
authored
Oct 18, 2024
by
Lianmin Zheng
Committed by
GitHub
Oct 18, 2024
Browse files
Add grouped free operations (#1706)
parent
392f2863
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
23 additions
and
4 deletions
+23
-4
python/sglang/srt/managers/scheduler.py
python/sglang/srt/managers/scheduler.py
+4
-0
python/sglang/srt/mem_cache/memory_pool.py
python/sglang/srt/mem_cache/memory_pool.py
+19
-4
No files found.
python/sglang/srt/managers/scheduler.py
View file @
bc12d403
...
...
@@ -834,6 +834,8 @@ class Scheduler:
next_token_ids
=
self
.
resolve_next_token_ids
(
bid
,
next_token_ids
)
self
.
token_to_kv_pool
.
free_group_begin
()
# Check finish condition
for
i
,
(
req
,
next_token_id
)
in
enumerate
(
zip
(
batch
.
reqs
,
next_token_ids
)):
if
self
.
server_args
.
enable_overlap_schedule
and
req
.
finished
():
...
...
@@ -860,6 +862,8 @@ class Scheduler:
self
.
stream_output
(
batch
.
reqs
)
self
.
token_to_kv_pool
.
free_group_end
()
self
.
decode_forward_ct
=
(
self
.
decode_forward_ct
+
1
)
%
(
1
<<
30
)
if
self
.
tp_rank
==
0
and
self
.
decode_forward_ct
%
40
==
0
:
self
.
print_decode_stats
()
...
...
python/sglang/srt/mem_cache/memory_pool.py
View file @
bc12d403
...
...
@@ -18,7 +18,6 @@ limitations under the License.
import
logging
from
typing
import
List
,
Tuple
,
Union
import
numpy
as
np
import
torch
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -77,6 +76,8 @@ class BaseTokenToKVPool:
self
.
store_dtype
=
dtype
self
.
free_slots
=
None
self
.
is_not_in_free_group
=
True
self
.
free_group
=
[]
self
.
clear
()
def
available_size
(
self
):
...
...
@@ -89,14 +90,28 @@ class BaseTokenToKVPool:
select_index
=
self
.
free_slots
[:
need_size
]
self
.
free_slots
=
self
.
free_slots
[
need_size
:]
return
torch
.
tensor
(
select_index
,
dtype
=
torch
.
int32
,
device
=
self
.
device
)
return
select_index
.
to
(
self
.
device
)
def
free
(
self
,
free_index
:
torch
.
Tensor
):
self
.
free_slots
=
np
.
concatenate
((
self
.
free_slots
,
free_index
.
cpu
().
numpy
()))
if
self
.
is_not_in_free_group
:
self
.
free_slots
=
torch
.
concat
((
self
.
free_slots
,
free_index
.
cpu
()))
else
:
self
.
free_group
.
append
(
free_index
)
def
free_group_begin
(
self
):
self
.
is_not_in_free_group
=
False
self
.
free_group
=
[]
def
free_group_end
(
self
):
self
.
is_not_in_free_group
=
True
if
self
.
free_group
:
self
.
free
(
torch
.
concat
(
self
.
free_group
))
def
clear
(
self
):
# The padded slot 0 is used for writing dummy outputs from padded tokens.
self
.
free_slots
=
np
.
arange
(
1
,
self
.
size
+
1
)
self
.
free_slots
=
torch
.
arange
(
1
,
self
.
size
+
1
,
dtype
=
torch
.
int32
)
self
.
is_in_free_group
=
False
self
.
free_group
=
[]
def
get_key_buffer
(
self
,
layer_id
:
int
)
->
torch
.
Tensor
:
raise
NotImplementedError
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment