Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
f2887498
Unverified
Commit
f2887498
authored
Aug 10, 2025
by
Lianmin Zheng
Committed by
GitHub
Aug 10, 2025
Browse files
Simplify memory pool (#9033)
parent
8ecf6b9d
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
35 additions
and
35 deletions
+35
-35
python/sglang/srt/model_executor/model_runner.py
python/sglang/srt/model_executor/model_runner.py
+35
-35
No files found.
python/sglang/srt/model_executor/model_runner.py
View file @
f2887498
...
...
@@ -1251,30 +1251,33 @@ class ModelRunner:
# Draft worker shares req_to_token_pool with the target worker.
assert
self
.
is_draft_worker
if
self
.
server_args
.
attention_backend
==
"ascend"
and
not
self
.
use_mla_backend
:
self
.
token_to_kv_pool
=
AscendTokenToKVPool
(
self
.
max_total_num_tokens
,
page_size
=
self
.
page_size
,
dtype
=
self
.
kv_cache_dtype
,
head_num
=
self
.
model_config
.
get_num_kv_heads
(
get_attention_tp_size
()),
head_dim
=
self
.
model_config
.
head_dim
,
layer_num
=
self
.
model_config
.
num_hidden_layers
,
device
=
self
.
device
,
enable_memory_saver
=
self
.
server_args
.
enable_memory_saver
,
)
elif
self
.
server_args
.
attention_backend
==
"ascend"
and
self
.
use_mla_backend
:
self
.
token_to_kv_pool
=
AscendMLAPagedTokenToKVPool
(
self
.
max_total_num_tokens
,
page_size
=
self
.
page_size
,
dtype
=
self
.
kv_cache_dtype
,
kv_lora_rank
=
self
.
model_config
.
kv_lora_rank
,
qk_rope_head_dim
=
self
.
model_config
.
qk_rope_head_dim
,
layer_num
=
self
.
num_effective_layers
,
device
=
self
.
device
,
enable_memory_saver
=
self
.
server_args
.
enable_memory_saver
,
start_layer
=
self
.
start_layer
,
end_layer
=
self
.
end_layer
,
)
if
self
.
server_args
.
attention_backend
==
"ascend"
:
if
self
.
use_mla_backend
:
self
.
token_to_kv_pool
=
AscendMLAPagedTokenToKVPool
(
self
.
max_total_num_tokens
,
page_size
=
self
.
page_size
,
dtype
=
self
.
kv_cache_dtype
,
kv_lora_rank
=
self
.
model_config
.
kv_lora_rank
,
qk_rope_head_dim
=
self
.
model_config
.
qk_rope_head_dim
,
layer_num
=
self
.
num_effective_layers
,
device
=
self
.
device
,
enable_memory_saver
=
self
.
server_args
.
enable_memory_saver
,
start_layer
=
self
.
start_layer
,
end_layer
=
self
.
end_layer
,
)
else
:
self
.
token_to_kv_pool
=
AscendTokenToKVPool
(
self
.
max_total_num_tokens
,
page_size
=
self
.
page_size
,
dtype
=
self
.
kv_cache_dtype
,
head_num
=
self
.
model_config
.
get_num_kv_heads
(
get_attention_tp_size
()
),
head_dim
=
self
.
model_config
.
head_dim
,
layer_num
=
self
.
model_config
.
num_hidden_layers
,
device
=
self
.
device
,
enable_memory_saver
=
self
.
server_args
.
enable_memory_saver
,
)
elif
self
.
use_mla_backend
:
self
.
token_to_kv_pool
=
MLATokenToKVPool
(
self
.
max_total_num_tokens
,
...
...
@@ -1333,6 +1336,7 @@ class ModelRunner:
end_layer
=
self
.
end_layer
,
)
need_sort
=
self
.
server_args
.
disaggregation_mode
in
(
"decode"
,
"prefill"
)
if
self
.
token_to_kv_pool_allocator
is
None
:
if
self
.
page_size
==
1
:
if
self
.
is_hybrid
:
...
...
@@ -1342,8 +1346,7 @@ class ModelRunner:
dtype
=
self
.
kv_cache_dtype
,
device
=
self
.
device
,
kvcache
=
self
.
token_to_kv_pool
,
need_sort
=
self
.
server_args
.
disaggregation_mode
in
(
"decode"
,
"prefill"
),
need_sort
=
need_sort
,
)
else
:
self
.
token_to_kv_pool_allocator
=
TokenToKVPoolAllocator
(
...
...
@@ -1351,29 +1354,26 @@ class ModelRunner:
dtype
=
self
.
kv_cache_dtype
,
device
=
self
.
device
,
kvcache
=
self
.
token_to_kv_pool
,
need_sort
=
self
.
server_args
.
disaggregation_mode
in
(
"decode"
,
"prefill"
),
need_sort
=
need_sort
,
)
else
:
if
_is_npu
:
self
.
token_to_kv_pool_allocator
=
Ascend
PagedTokenToKVPoolAllocator
(
if
not
_is_npu
:
self
.
token_to_kv_pool_allocator
=
PagedTokenToKVPoolAllocator
(
self
.
max_total_num_tokens
,
page_size
=
self
.
page_size
,
dtype
=
self
.
kv_cache_dtype
,
device
=
self
.
device
,
kvcache
=
self
.
token_to_kv_pool
,
need_sort
=
self
.
server_args
.
disaggregation_mode
in
(
"decode"
,
"prefill"
),
need_sort
=
need_sort
,
)
else
:
self
.
token_to_kv_pool_allocator
=
PagedTokenToKVPoolAllocator
(
self
.
token_to_kv_pool_allocator
=
Ascend
PagedTokenToKVPoolAllocator
(
self
.
max_total_num_tokens
,
page_size
=
self
.
page_size
,
dtype
=
self
.
kv_cache_dtype
,
device
=
self
.
device
,
kvcache
=
self
.
token_to_kv_pool
,
need_sort
=
self
.
server_args
.
disaggregation_mode
in
(
"decode"
,
"prefill"
),
need_sort
=
need_sort
,
)
else
:
assert
self
.
is_draft_worker
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment