commit bc12d403 (unverified)
Author:     Lianmin Zheng
Authored:   Oct 18, 2024
Committer:  GitHub
Committed:  Oct 18, 2024
Parent:     392f2863

    Add grouped free operations (#1706)
Changes: 2 changed files, with 23 additions and 4 deletions.

    python/sglang/srt/managers/scheduler.py     +4   -0
    python/sglang/srt/mem_cache/memory_pool.py  +19  -4
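In short: BaseTokenToKVPool gains a grouped-free mode. free_group_begin() switches free() from applying immediately to queuing indices in free_group, and free_group_end() flushes the queue as one concatenated free(). The scheduler brackets its decode finish-condition loop with this pair, so the per-request frees of a decode step collapse into a single batched operation. Alongside this, the free list moves from a NumPy array to a torch int32 tensor, and the allocation path now returns the preselected slice moved to the device rather than rebuilding a tensor.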
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -834,6 +834,8 @@ class Scheduler:
         next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
 
+        self.token_to_kv_pool.free_group_begin()
+
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
             if self.server_args.enable_overlap_schedule and req.finished():

@@ -860,6 +862,8 @@ class Scheduler:
         self.stream_output(batch.reqs)
 
+        self.token_to_kv_pool.free_group_end()
+
         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
         if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
             self.print_decode_stats()
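Read together, the two scheduler hunks bracket the decode bookkeeping: free_group_begin() runs before the finish-condition loop, so every free() issued for a finished request inside it is deferred, and free_group_end() after stream_output() applies them all at once. A runnable sketch of the pool-side mechanics follows the memory_pool.py diff below.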
--- a/python/sglang/srt/mem_cache/memory_pool.py
+++ b/python/sglang/srt/mem_cache/memory_pool.py
@@ -18,7 +18,6 @@ limitations under the License.
 import logging
 from typing import List, Tuple, Union
 
-import numpy as np
 import torch
 
 logger = logging.getLogger(__name__)

@@ -77,6 +76,8 @@ class BaseTokenToKVPool:
         self.store_dtype = dtype
 
         self.free_slots = None
+        self.is_not_in_free_group = True
+        self.free_group = []
         self.clear()
 
     def available_size(self):

@@ -89,14 +90,28 @@ class BaseTokenToKVPool:
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
 
-        return torch.tensor(select_index, dtype=torch.int32, device=self.device)
+        return select_index.to(self.device)
 
     def free(self, free_index: torch.Tensor):
-        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
+        if self.is_not_in_free_group:
+            self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
+        else:
+            self.free_group.append(free_index)
+
+    def free_group_begin(self):
+        self.is_not_in_free_group = False
+        self.free_group = []
+
+    def free_group_end(self):
+        self.is_not_in_free_group = True
+        if self.free_group:
+            self.free(torch.concat(self.free_group))
 
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots = np.arange(1, self.size + 1)
+        self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
+        self.is_in_free_group = False
+        self.free_group = []
 
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
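To make the semantics above concrete, here is a minimal, self-contained sketch of the grouped-free pattern, assuming a toy GroupedFreeList stand-in (the class name and the condensed alloc-style method are illustrative, not the real BaseTokenToKVPool):

import torch

class GroupedFreeList:
    # Toy stand-in for BaseTokenToKVPool's free-list bookkeeping.
    def __init__(self, size: int):
        self.size = size
        self.clear()

    def clear(self):
        # Slot 0 is reserved for dummy outputs, as in the real pool.
        self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
        self.is_not_in_free_group = True
        self.free_group = []

    def alloc(self, need_size: int) -> torch.Tensor:
        # Condensed allocation: take the first need_size free slots.
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        return select_index

    def free(self, free_index: torch.Tensor):
        if self.is_not_in_free_group:
            # Normal mode: splice freed slots back immediately.
            self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
        else:
            # Group mode: just queue; free_group_end() will flush.
            self.free_group.append(free_index)

    def free_group_begin(self):
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        self.is_not_in_free_group = True
        if self.free_group:
            # One concat + one free instead of N small ones.
            self.free(torch.concat(self.free_group))

pool = GroupedFreeList(size=8)
idx = pool.alloc(3)                  # take slots 1, 2, 3

pool.free_group_begin()
pool.free(idx[:2])                   # queued; free_slots untouched
pool.free(idx[2:])                   # queued as well
assert len(pool.free_group) == 2
pool.free_group_end()                # single batched free
assert pool.free_slots.numel() == 8  # all slots are back

Deferring to one concat per group avoids reallocating and copying the free-slot tensor once per finished request; the free_index.cpu() call in the diff also suggests freed indices can arrive on the GPU while the free list itself stays on the host.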