zhaoyu6 / sglang · Commit 173e0f70 (Unverified)

Enable memory saver for hybrid model (#11974)

Authored Nov 04, 2025 by Junrong Lin, committed by GitHub on Nov 04, 2025
Parent: f600866a
Showing 3 changed files with 15 additions and 8 deletions.
python/sglang/srt/mem_cache/memory_pool.py        +12 -8
python/sglang/srt/model_executor/model_runner.py   +1 -0
test/srt/test_mamba_unittest.py                    +2 -0
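The change applies one pattern across all three files: an enable_memory_saver flag is threaded down into the hybrid pools, and the pools' GPU state is allocated inside a TorchMemorySaverAdapter.region(...) scope so the memory saver can track (and later release and restore) those allocations. The sketch below illustrates that pattern in isolation; SaverAdapter, ToyStatePool, and the "kv_cache" tag are stand-ins for this note, not sglang APIs.

import contextlib

import torch


class SaverAdapter:
    """Stand-in for TorchMemorySaverAdapter: a no-op unless explicitly enabled."""

    def __init__(self, enable: bool):
        self.enable = enable

    @classmethod
    def create(cls, enable: bool) -> "SaverAdapter":
        return cls(enable=enable)

    @contextlib.contextmanager
    def region(self, tag: str):
        # A real adapter would register allocations made inside this scope under
        # `tag` so they can be released and re-materialized later; here we only log.
        if self.enable:
            print(f"[memory-saver] enter region '{tag}'")
        try:
            yield
        finally:
            if self.enable:
                print(f"[memory-saver] exit region '{tag}'")


class ToyStatePool:
    """Hypothetical pool whose state tensors are allocated inside the saver region."""

    def __init__(self, size: int, device: str, enable_memory_saver: bool):
        self.memory_saver_adapter = SaverAdapter.create(enable=enable_memory_saver)
        with self.memory_saver_adapter.region("kv_cache"):
            self.free_slots = torch.arange(size, dtype=torch.int64, device=device)
            self.state = torch.zeros(size, 16, device=device)


pool = ToyStatePool(size=8, device="cpu", enable_memory_saver=True)
print(pool.free_slots.shape, pool.state.shape)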
python/sglang/srt/mem_cache/memory_pool.py
@@ -154,12 +154,19 @@ class MambaPool:
         size: int,
         cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
         device: str,
+        enable_memory_saver: bool,
         speculative_num_draft_tokens: Optional[int] = None,
     ):
         conv_state_shape = cache_params.shape.conv
         temporal_state_shape = cache_params.shape.temporal
         conv_dtype = cache_params.dtype.conv
         ssm_dtype = cache_params.dtype.temporal
+        self.size = size
+        self.device = device
+        self.free_slots = torch.arange(
+            self.size, dtype=torch.int64, device=self.device
+        )
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(enable=enable_memory_saver)
         num_mamba_layers = len(cache_params.layers)
         # for disagg with nvlink
@@ -174,9 +181,8 @@ class MambaPool:
             self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
         else:
             self.custom_mem_pool = None
         self.is_kda_cache = isinstance(cache_params, KimiLinearCacheParams)
-        with (
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE), (
             torch.cuda.use_mem_pool(self.custom_mem_pool)
             if self.enable_custom_mem_pool
             else nullcontext()
@@ -270,11 +276,6 @@ class MambaPool:
             f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f} GB, "
             f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f} GB "
         )
-        self.size = size
-        self.device = device
-        self.free_slots = torch.arange(
-            self.size, dtype=torch.int64, device=self.device
-        )
         self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
         self.num_mamba_layers = num_mamba_layers
@@ -369,6 +370,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
             device=device,
             enable_memory_saver=enable_memory_saver,
         )
+        self.enable_memory_saver = enable_memory_saver
         self._init_mamba_pool(
             size=mamba_size,
             cache_params=cache_params,
@@ -387,6 +389,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
             size=size,
             cache_params=cache_params,
             device=device,
+            enable_memory_saver=self.enable_memory_saver,
             speculative_num_draft_tokens=speculative_num_draft_tokens,
         )
         self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
@@ -867,6 +870,7 @@ class HybridLinearKVPool(KVCache):
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
         device: str,
+        enable_memory_saver: bool,
         mamba_pool: MambaPool,
         # TODO: refactor mla related args
         use_mla: bool = False,
@@ -899,7 +903,7 @@ class HybridLinearKVPool(KVCache):
                 head_dim=head_dim,
                 layer_num=self.full_layer_nums,
                 device=device,
-                enable_memory_saver=False,
+                enable_memory_saver=enable_memory_saver,
             )
         else:
             TokenToKVPoolClass = MLATokenToKVPool
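One detail of the rewritten `with` statement above: the memory-saver region is combined in a single `with` with the optional custom CUDA memory pool, which falls back to nullcontext() when the custom pool is disabled. Below is a minimal, sglang-independent sketch of that composition; the region helper and the use_custom_pool flag are placeholders, not the real adapter or attribute.

from contextlib import contextmanager, nullcontext


@contextmanager
def region(tag: str):
    # Placeholder for self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE).
    print(f"enter {tag}")
    try:
        yield
    finally:
        print(f"exit {tag}")


use_custom_pool = False  # placeholder for self.enable_custom_mem_pool

# The always-on region and the conditional pool context share one `with`,
# mirroring the shape of the updated code.
with region("kv_cache"), (
    region("custom_mem_pool") if use_custom_pool else nullcontext()
):
    buffers = [bytearray(1024) for _ in range(4)]  # allocations happen inside both scopes

print(len(buffers))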
python/sglang/srt/model_executor/model_runner.py
@@ -1798,6 +1798,7 @@ class ModelRunner:
                 ),
                 enable_kvcache_transpose=False,
                 device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
                 mamba_pool=self.req_to_token_pool.mamba_pool,
                 use_mla=self.use_mla_backend,
                 **extra_args,
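With this line, the hybrid pool honors the same setting as the other KV pools: ServerArgs.enable_memory_saver (presumably toggled by the existing --enable-memory-saver launch option) is now forwarded into HybridLinearKVPool, which previously hard-coded enable_memory_saver=False for its full-attention sub-pool.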
test/srt/test_mamba_unittest.py
@@ -42,6 +42,7 @@ class TestMamba(unittest.TestCase):
             full_attention_layer_ids=full_attention_layer_ids,
             enable_kvcache_transpose=False,
             device=device,
+            enable_memory_saver=False,
             mamba_pool=None,
         )
         assert pool._transfer_full_attention_id(global_interval - 1) == 0
@@ -174,6 +175,7 @@ class TestMamba(unittest.TestCase):
             full_attention_layer_ids=full_attention_layer_ids,
             enable_kvcache_transpose=False,
             device=device,
+            enable_memory_saver=False,
             mamba_pool=req_to_token_pool.mamba_pool,
         )