change / sglang · Commits

Commit 4ae0969c (unverified)
Move status check in the memory pool to CPU (#1557)

Authored Oct 02, 2024 by Lianmin Zheng; committed via GitHub on Oct 02, 2024
Parent: 317631ca
Showing 2 changed files with 27 additions and 42 deletions:

  python/sglang/srt/mem_cache/memory_pool.py          +21  -41
  python/sglang/srt/model_executor/model_runner.py     +6   -1
python/sglang/srt/mem_cache/memory_pool.py
@@ -19,6 +19,7 @@ import logging
 from abc import ABC, abstractmethod
 from typing import List, Tuple, Union
 
+import numpy as np
 import torch
 
 logger = logging.getLogger(__name__)
@@ -69,56 +70,27 @@ class BaseTokenToKVPool(ABC):
         else:
             self.store_dtype = dtype
 
-        # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state = torch.ones((self.size + 1,), dtype=torch.bool, device="cuda")
-
-        # Prefetch buffer
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
-        self.prefetch_chunk_size = 512
-
-        self.can_use_mem_size = self.size
+        self.free_slots = None
         self.clear()
 
     def available_size(self):
-        return self.can_use_mem_size + len(self.prefetch_buffer)
+        return len(self.free_slots)
 
     def alloc(self, need_size: int):
-        buffer_len = len(self.prefetch_buffer)
-        if need_size <= buffer_len:
-            select_index = self.prefetch_buffer[:need_size]
-            self.prefetch_buffer = self.prefetch_buffer[need_size:]
-            return select_index
-
-        addition_size = need_size - buffer_len
-        alloc_size = max(addition_size, self.prefetch_chunk_size)
-        select_index = (
-            torch.nonzero(self.mem_state).squeeze(1)[:alloc_size].to(torch.int32)
-        )
-
-        if select_index.shape[0] < addition_size:
+        if need_size > len(self.free_slots):
             return None
 
-        self.mem_state[select_index] = False
-        self.can_use_mem_size -= len(select_index)
-
-        self.prefetch_buffer = torch.cat((self.prefetch_buffer, select_index))
-        ret_index = self.prefetch_buffer[:need_size]
-        self.prefetch_buffer = self.prefetch_buffer[need_size:]
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
 
-        return ret_index
+        return torch.tensor(select_index, dtype=torch.int32, device="cuda")
 
     def free(self, free_index: torch.Tensor):
-        self.mem_state[free_index] = True
-        self.can_use_mem_size += len(free_index)
+        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))
 
     def clear(self):
-        self.prefetch_buffer = torch.empty(0, device="cuda", dtype=torch.int32)
-
-        self.mem_state.fill_(True)
-        self.can_use_mem_size = self.size
-
-        # We also add one slot. This slot is used for writing dummy output from padded tokens.
-        self.mem_state[0] = False
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
+        self.free_slots = np.arange(1, self.size + 1)
 
     @abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
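The core of the change is visible in the hunk above: the pool no longer tracks slot availability in a boolean mem_state tensor on the GPU (queried via torch.nonzero) but keeps a plain NumPy free list on the host, so available_size() and alloc() never touch the device until the selected indices are materialized. A minimal standalone sketch of that scheme (hypothetical class name, illustrative only, not sglang code):

import numpy as np
import torch


class FreeListPool:
    """Minimal sketch of a host-side free-list allocator (illustrative, not sglang code)."""

    def __init__(self, size: int, device: str = "cpu"):
        # The real pool passes a device string such as "cuda"; "cpu" keeps this sketch runnable anywhere.
        self.size = size
        self.device = device
        self.clear()

    def clear(self):
        # Slot 0 is reserved for dummy writes from padded tokens, so usable slots are 1..size.
        self.free_slots = np.arange(1, self.size + 1)

    def available_size(self) -> int:
        return len(self.free_slots)  # pure host-side check, no GPU kernel or sync

    def alloc(self, need_size: int):
        if need_size > len(self.free_slots):
            return None  # out of KV cache slots
        select_index = self.free_slots[:need_size]
        self.free_slots = self.free_slots[need_size:]
        # Only the chosen indices are shipped to the device.
        return torch.tensor(select_index, dtype=torch.int32, device=self.device)

    def free(self, free_index: torch.Tensor):
        # Returned slots go back onto the host-side free list.
        self.free_slots = np.concatenate((self.free_slots, free_index.cpu().numpy()))


pool = FreeListPool(size=8)
idx = pool.alloc(3)                  # tensor of 3 slot indices
pool.free(idx)
assert pool.available_size() == 8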
@@ -152,19 +124,25 @@ class MHATokenToKVPool(BaseTokenToKVPool):
         head_num: int,
         head_dim: int,
         layer_num: int,
+        device: str,
     ):
         super().__init__(size, dtype)
 
         # [size, head_num, head_dim] for each layer
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.k_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
         self.v_buffer = [
             torch.empty(
-                (size + 1, head_num, head_dim), dtype=self.store_dtype, device="cuda"
+                (size + 1, head_num, head_dim),
+                dtype=self.store_dtype,
+                device=device,
             )
             for _ in range(layer_num)
         ]
@@ -210,15 +188,17 @@ class MLATokenToKVPool(BaseTokenToKVPool):
         kv_lora_rank: int,
         qk_rope_head_dim: int,
         layer_num: int,
+        device: str,
     ):
         super().__init__(size, dtype)
 
         self.kv_lora_rank = kv_lora_rank
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.kv_buffer = [
             torch.empty(
                 (size + 1, 1, kv_lora_rank + qk_rope_head_dim),
                 dtype=self.store_dtype,
-                device="cuda",
+                device=device,
             )
             for _ in range(layer_num)
         ]
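The comment added to both KV pool constructors documents the existing convention behind the size + 1 allocation: index 0 is a scratch slot, so padded tokens can write dummy KV output there while real tokens only ever occupy slots 1..size handed out by the free list. A rough standalone illustration of that convention (hypothetical shapes, not sglang code):

import torch

size, head_num, head_dim = 4, 2, 8

# One extra row: index 0 is the scratch slot for padded tokens.
k_buffer = torch.zeros((size + 1, head_num, head_dim))

real_slots = torch.tensor([1, 2, 3])   # handed out by the free list, never 0
padded_slots = torch.tensor([0, 0])    # every padded token points at slot 0

k_buffer[real_slots] = torch.randn(3, head_num, head_dim)    # real KV entries
k_buffer[padded_slots] = torch.randn(2, head_num, head_dim)  # dummy writes, harmless

# The dummy writes only ever clobber slot 0; slots 1..size keep their real data.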
python/sglang/srt/model_executor/model_runner.py
@@ -409,8 +409,11 @@ class ModelRunner:
             4096,
         )
 
+        device = "cuda"
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs + 1, self.model_config.context_len + 4, device="cuda"
+            max_num_reqs + 1,
+            self.model_config.context_len + 4,
+            device=device,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -422,6 +425,7 @@ class ModelRunner:
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
+                device=device,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -430,6 +434,7 @@ class ModelRunner:
                 head_num=self.model_config.get_num_kv_heads(self.tp_size),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
+                device=device,
             )
         logger.info(
             f"Memory pool end. "
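model_runner.py now defines the target device once and threads it into every pool constructor. The payoff of the memory_pool.py change is that the scheduler's "do we have room?" question no longer involves the GPU at all. A rough, self-contained comparison of the two status-check styles (hypothetical sizes, not sglang code): counting free slots in a device-resident boolean mask needs a kernel launch and a device-to-host sync per check, while the free-list variant is a plain len() on a NumPy array.

import time

import numpy as np
import torch

size = 1_000_000

if torch.cuda.is_available():
    mem_state = torch.ones((size,), dtype=torch.bool, device="cuda")  # old-style GPU mask
    free_slots = np.arange(1, size + 1)                               # new-style CPU free list

    torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(100):
        n_gpu = int(torch.nonzero(mem_state).shape[0])  # kernel launch + implicit sync
    torch.cuda.synchronize()
    t1 = time.perf_counter()

    for _ in range(100):
        n_cpu = len(free_slots)                          # no GPU work at all
    t2 = time.perf_counter()

    assert n_gpu == n_cpu == size
    print(f"GPU mask check: {(t1 - t0) * 10:.3f} ms/call")
    print(f"CPU free list:  {(t2 - t1) * 10:.3f} ms/call")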