Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
zhaoyu6
sglang
Commits
476584cb
"docs/vscode:/vscode.git/clone" did not exist on "d7067e4430f4e99697b382b0e3ea5597af737a2c"
Unverified
Commit
476584cb
authored
Jul 17, 2024
by
Ying Sheng
Committed by
GitHub
Jul 17, 2024
Browse files
Increase the capacity of the memory pool (#643)
parent
abd5385a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
16 deletions
+18
-16
python/sglang/srt/managers/controller/cuda_graph_runner.py
python/sglang/srt/managers/controller/cuda_graph_runner.py
+2
-4
python/sglang/srt/managers/controller/infer_batch.py
python/sglang/srt/managers/controller/infer_batch.py
+5
-0
python/sglang/srt/managers/controller/model_runner.py
python/sglang/srt/managers/controller/model_runner.py
+7
-8
python/sglang/srt/memory_pool.py
python/sglang/srt/memory_pool.py
+4
-4
No files found.
python/sglang/srt/managers/controller/cuda_graph_runner.py
View file @
476584cb
...
@@ -3,9 +3,10 @@
...
@@ -3,9 +3,10 @@
import
bisect
import
bisect
import
torch
import
torch
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
_grouped_size_compiled_for_decode_kernels
from
vllm.distributed.parallel_state
import
graph_capture
from
vllm.distributed.parallel_state
import
graph_capture
from
sglang.global_config
import
global_config
from
sglang.srt.layers.logits_processor
import
LogitProcessorOutput
from
sglang.srt.layers.logits_processor
import
LogitProcessorOutput
from
sglang.srt.managers.controller.infer_batch
import
(
from
sglang.srt.managers.controller.infer_batch
import
(
Batch
,
Batch
,
...
@@ -74,9 +75,6 @@ class CudaGraphRunner:
...
@@ -74,9 +75,6 @@ class CudaGraphRunner:
self
.
flashinfer_handlers
[
bs
]
=
flashinfer_handler
self
.
flashinfer_handlers
[
bs
]
=
flashinfer_handler
def
capture_one_batch_size
(
self
,
bs
):
def
capture_one_batch_size
(
self
,
bs
):
from
flashinfer
import
BatchDecodeWithPagedKVCacheWrapper
from
flashinfer.decode
import
_grouped_size_compiled_for_decode_kernels
graph
=
torch
.
cuda
.
CUDAGraph
()
graph
=
torch
.
cuda
.
CUDAGraph
()
stream
=
self
.
stream
stream
=
self
.
stream
...
...
python/sglang/srt/managers/controller/infer_batch.py
View file @
476584cb
...
@@ -325,6 +325,11 @@ class Batch:
...
@@ -325,6 +325,11 @@ class Batch:
seq_lens
=
[]
seq_lens
=
[]
req_pool_indices
=
self
.
req_to_token_pool
.
alloc
(
bs
)
req_pool_indices
=
self
.
req_to_token_pool
.
alloc
(
bs
)
if
req_pool_indices
is
None
:
raise
RuntimeError
(
"Out of memory. "
"Please set a smaller number for `--max-running-requests`."
)
req_pool_indices_cpu
=
req_pool_indices
.
cpu
().
numpy
()
req_pool_indices_cpu
=
req_pool_indices
.
cpu
().
numpy
()
for
i
in
range
(
bs
):
for
i
in
range
(
bs
):
flatten_input_ids
.
extend
(
input_ids
[
i
])
flatten_input_ids
.
extend
(
input_ids
[
i
])
...
...
python/sglang/srt/managers/controller/model_runner.py
View file @
476584cb
...
@@ -9,6 +9,12 @@ from typing import Optional, Type
...
@@ -9,6 +9,12 @@ from typing import Optional, Type
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
flashinfer
import
(
BatchDecodeWithPagedKVCacheWrapper
,
BatchPrefillWithPagedKVCacheWrapper
,
BatchPrefillWithRaggedKVCacheWrapper
,
)
from
flashinfer.decode
import
_grouped_size_compiled_for_decode_kernels
from
vllm.config
import
DeviceConfig
,
LoadConfig
from
vllm.config
import
DeviceConfig
,
LoadConfig
from
vllm.config
import
ModelConfig
as
VllmModelConfig
from
vllm.config
import
ModelConfig
as
VllmModelConfig
from
vllm.distributed
import
(
from
vllm.distributed
import
(
...
@@ -162,7 +168,7 @@ class ModelRunner:
...
@@ -162,7 +168,7 @@ class ModelRunner:
)
)
self
.
req_to_token_pool
=
ReqToTokenPool
(
self
.
req_to_token_pool
=
ReqToTokenPool
(
int
(
self
.
max_total_num_tokens
/
self
.
model_config
.
context_len
*
256
),
max
(
int
(
self
.
max_total_num_tokens
/
self
.
model_config
.
context_len
*
512
),
2048
),
self
.
model_config
.
context_len
+
8
,
self
.
model_config
.
context_len
+
8
,
)
)
self
.
token_to_kv_pool
=
TokenToKVPool
(
self
.
token_to_kv_pool
=
TokenToKVPool
(
...
@@ -193,13 +199,6 @@ class ModelRunner:
...
@@ -193,13 +199,6 @@ class ModelRunner:
self
.
flashinfer_decode_wrapper
=
None
self
.
flashinfer_decode_wrapper
=
None
return
return
from
flashinfer
import
(
BatchDecodeWithPagedKVCacheWrapper
,
BatchPrefillWithPagedKVCacheWrapper
,
BatchPrefillWithRaggedKVCacheWrapper
,
)
from
flashinfer.decode
import
_grouped_size_compiled_for_decode_kernels
if
not
_grouped_size_compiled_for_decode_kernels
(
if
not
_grouped_size_compiled_for_decode_kernels
(
self
.
model_config
.
num_attention_heads
//
self
.
tp_size
,
self
.
model_config
.
num_attention_heads
//
self
.
tp_size
,
self
.
model_config
.
get_num_kv_heads
(
self
.
tp_size
),
self
.
model_config
.
get_num_kv_heads
(
self
.
tp_size
),
...
...
python/sglang/srt/memory_pool.py
View file @
476584cb
...
@@ -44,7 +44,7 @@ class ReqToTokenPool:
...
@@ -44,7 +44,7 @@ class ReqToTokenPool:
class
TokenToKVPool
:
class
TokenToKVPool
:
"""A memory pool that maps a token to its kv cache locations"""
"""A memory pool that maps a token to its kv cache locations"""
def
__init__
(
self
,
size
,
dtype
,
head_num
,
head_dim
,
layer_num
):
def
__init__
(
self
,
size
:
int
,
dtype
:
torch
.
dtype
,
head_num
:
int
,
head_dim
:
int
,
layer_num
:
int
):
self
.
size
=
size
self
.
size
=
size
# We also add one slot. This slot is used for writing dummy output from padded tokens.
# We also add one slot. This slot is used for writing dummy output from padded tokens.
...
@@ -63,16 +63,16 @@ class TokenToKVPool:
...
@@ -63,16 +63,16 @@ class TokenToKVPool:
self
.
can_use_mem_size
=
self
.
size
self
.
can_use_mem_size
=
self
.
size
self
.
clear
()
self
.
clear
()
def
get_key_buffer
(
self
,
layer_id
):
def
get_key_buffer
(
self
,
layer_id
:
int
):
return
self
.
kv_data
[
layer_id
][:,
0
]
return
self
.
kv_data
[
layer_id
][:,
0
]
def
get_value_buffer
(
self
,
layer_id
):
def
get_value_buffer
(
self
,
layer_id
:
int
):
return
self
.
kv_data
[
layer_id
][:,
1
]
return
self
.
kv_data
[
layer_id
][:,
1
]
def
available_size
(
self
):
def
available_size
(
self
):
return
self
.
can_use_mem_size
+
len
(
self
.
prefetch_buffer
)
return
self
.
can_use_mem_size
+
len
(
self
.
prefetch_buffer
)
def
alloc
(
self
,
need_size
):
def
alloc
(
self
,
need_size
:
int
):
buffer_len
=
len
(
self
.
prefetch_buffer
)
buffer_len
=
len
(
self
.
prefetch_buffer
)
if
need_size
<=
buffer_len
:
if
need_size
<=
buffer_len
:
select_index
=
self
.
prefetch_buffer
[:
need_size
]
select_index
=
self
.
prefetch_buffer
[:
need_size
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment