sglang · Commit 476584cb (unverified)

Increase the capacity of the memory pool (#643)

Authored Jul 17, 2024 by Ying Sheng, committed via GitHub on Jul 17, 2024
Parent: abd5385a

Showing 4 changed files with 18 additions and 16 deletions (+18 -16):

- python/sglang/srt/managers/controller/cuda_graph_runner.py (+2 -4)
- python/sglang/srt/managers/controller/infer_batch.py (+5 -0)
- python/sglang/srt/managers/controller/model_runner.py (+7 -8)
- python/sglang/srt/memory_pool.py (+4 -4)
python/sglang/srt/managers/controller/cuda_graph_runner.py

@@ -3,9 +3,10 @@
 import bisect

 import torch
+from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.distributed.parallel_state import graph_capture

 from sglang.global_config import global_config
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.controller.infer_batch import (
     Batch,
@@ -74,9 +75,6 @@ class CudaGraphRunner:
         self.flashinfer_handlers[bs] = flashinfer_handler

     def capture_one_batch_size(self, bs):
-        from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
         graph = torch.cuda.CUDAGraph()
         stream = self.stream
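These two hunks hoist the flashinfer imports out of `capture_one_batch_size` to module scope, so the symbols are resolved once at module load rather than on every capture call. A minimal illustration of the pattern (a sketch using `bisect` as a stand-in module; this is not sglang code):

```python
# Before: the body re-executes the import statement on each call
# (cheap after the first time, but still a sys.modules lookup per call).
def lookup_local_import(bs):
    import bisect
    return bisect.bisect_left([1, 2, 4, 8, 16], bs)

# After: resolve the name once when the module is loaded.
import bisect

def lookup_module_import(bs):
    return bisect.bisect_left([1, 2, 4, 8, 16], bs)

# Both behave identically; only the import cost moves.
assert lookup_local_import(3) == lookup_module_import(3) == 2
```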
python/sglang/srt/managers/controller/infer_batch.py

@@ -325,6 +325,11 @@ class Batch:
         seq_lens = []

         req_pool_indices = self.req_to_token_pool.alloc(bs)
+        if req_pool_indices is None:
+            raise RuntimeError(
+                "Out of memory. "
+                "Please set a smaller number for `--max-running-requests`."
+            )
         req_pool_indices_cpu = req_pool_indices.cpu().numpy()
         for i in range(bs):
             flatten_input_ids.extend(input_ids[i])
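The added check pairs with `ReqToTokenPool.alloc` returning `None` when no request slots are free, turning a later crash on `req_pool_indices.cpu()` into an immediate, actionable error. A minimal sketch of the pattern with a toy pool (hypothetical names; the real pool hands out tensor indices):

```python
class ToyPool:
    def __init__(self, capacity: int):
        self.free_slots = list(range(capacity))

    def alloc(self, need_size: int):
        # Mirror the convention in the diff: return None instead of raising,
        # so the caller decides how to surface the failure.
        if need_size > len(self.free_slots):
            return None
        out, self.free_slots = self.free_slots[:need_size], self.free_slots[need_size:]
        return out

pool = ToyPool(capacity=2)
indices = pool.alloc(4)  # asks for more slots than exist
if indices is None:
    raise RuntimeError(
        "Out of memory. Please set a smaller number for `--max-running-requests`."
    )
```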
python/sglang/srt/managers/controller/model_runner.py

@@ -9,6 +9,12 @@ from typing import Optional, Type

 import torch
 import torch.nn as nn
+from flashinfer import (
+    BatchDecodeWithPagedKVCacheWrapper,
+    BatchPrefillWithPagedKVCacheWrapper,
+    BatchPrefillWithRaggedKVCacheWrapper,
+)
+from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
@@ -162,7 +168,7 @@ class ModelRunner:
         )
         self.req_to_token_pool = ReqToTokenPool(
-            int(self.max_total_num_tokens / self.model_config.context_len * 256),
+            max(int(self.max_total_num_tokens / self.model_config.context_len * 512), 2048),
             self.model_config.context_len + 8,
         )
         self.token_to_kv_pool = TokenToKVPool(
@@ -193,13 +199,6 @@
         self.flashinfer_decode_wrapper = None
         return

-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-            BatchPrefillWithRaggedKVCacheWrapper,
-        )
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
         if not _grouped_size_compiled_for_decode_kernels(
             self.model_config.num_attention_heads // self.tp_size,
             self.model_config.get_num_kv_heads(self.tp_size),
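The sizing change in the second hunk is where the commit title comes from: the `ReqToTokenPool` request-slot budget doubles from 256 to 512 average-context-length requests' worth of tokens, with a new floor of 2048 slots. A quick worked example (the token and context-length values below are hypothetical, chosen only to make the arithmetic visible):

```python
def req_slots_old(max_total_num_tokens: int, context_len: int) -> int:
    return int(max_total_num_tokens / context_len * 256)

def req_slots_new(max_total_num_tokens: int, context_len: int) -> int:
    return max(int(max_total_num_tokens / context_len * 512), 2048)

print(req_slots_old(100_000, 8_192))    # 3125
print(req_slots_new(100_000, 8_192))    # 6250  (twice the old capacity)
print(req_slots_new(200_000, 131_072))  # 2048  (the floor kicks in for long contexts)
```

The `max(..., 2048)` floor matters for long-context models, where dividing the token budget by a very large `context_len` would otherwise leave only a handful of request slots.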
python/sglang/srt/memory_pool.py

@@ -44,7 +44,7 @@ class ReqToTokenPool:
 class TokenToKVPool:
     """A memory pool that maps a token to its kv cache locations"""

-    def __init__(self, size, dtype, head_num, head_dim, layer_num):
+    def __init__(self, size: int, dtype: torch.dtype, head_num: int, head_dim: int, layer_num: int):
         self.size = size

         # We also add one slot. This slot is used for writing dummy output from padded tokens.
@@ -63,16 +63,16 @@ class TokenToKVPool:
         self.can_use_mem_size = self.size
         self.clear()

-    def get_key_buffer(self, layer_id):
+    def get_key_buffer(self, layer_id: int):
         return self.kv_data[layer_id][:, 0]

-    def get_value_buffer(self, layer_id):
+    def get_value_buffer(self, layer_id: int):
         return self.kv_data[layer_id][:, 1]

     def available_size(self):
         return self.can_use_mem_size + len(self.prefetch_buffer)

-    def alloc(self, need_size):
+    def alloc(self, need_size: int):
         buffer_len = len(self.prefetch_buffer)
         if need_size <= buffer_len:
             select_index = self.prefetch_buffer[:need_size]
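The newly annotated `alloc` drains a prefetch buffer before touching the main pool, which is why `available_size` counts both. A self-contained sketch of that two-tier idea, using plain lists and hypothetical names rather than the tensor bookkeeping of the real `TokenToKVPool`:

```python
class ToyTwoTierPool:
    def __init__(self, size: int):
        self.can_use_mem_size = size          # slots still in the main pool
        self.free_slots = list(range(size))   # stand-in for the real index tensor
        self.prefetch_buffer = []             # slots already pulled out in a batch

    def available_size(self) -> int:
        # Free slots plus anything parked in the prefetch buffer.
        return self.can_use_mem_size + len(self.prefetch_buffer)

    def alloc(self, need_size: int):
        # Fast path: serve entirely from the prefetch buffer.
        if need_size <= len(self.prefetch_buffer):
            select_index = self.prefetch_buffer[:need_size]
            self.prefetch_buffer = self.prefetch_buffer[need_size:]
            return select_index
        # Slow path: top up from the main pool (simplified vs. the real code).
        if need_size > self.available_size():
            return None
        take = need_size - len(self.prefetch_buffer)
        select_index = self.prefetch_buffer + self.free_slots[:take]
        self.free_slots = self.free_slots[take:]
        self.prefetch_buffer = []
        self.can_use_mem_size -= take
        return select_index

pool = ToyTwoTierPool(size=8)
print(pool.alloc(3), pool.available_size())  # [0, 1, 2] 5
```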