sglang · Commit d22d0447 (unverified)
Authored Nov 04, 2025 by Baizhou Zhang; committed by GitHub on Nov 04, 2025.

Revert "Enable memory saver for hybrid model" (#12648)
Parent: 887742a1
Showing 3 changed files, with 8 additions and 15 deletions (+8 -15):
python/sglang/srt/mem_cache/memory_pool.py        +8  -12
python/sglang/srt/model_executor/model_runner.py  +0  -1
test/srt/test_mamba_unittest.py                   +0  -2
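In short, this revert strips the memory-saver plumbing back out of the hybrid (Mamba) cache path: the enable_memory_saver parameter is removed from MambaPool.__init__ and HybridLinearKVPool.__init__, the TorchMemorySaverAdapter region around the Mamba cache allocation is dropped, and the full-attention token-to-KV pool inside HybridLinearKVPool is now always constructed with enable_memory_saver=False. A sketch of the allocation-context change follows the second memory_pool.py hunk below.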
python/sglang/srt/mem_cache/memory_pool.py
@@ -154,19 +154,12 @@ class MambaPool:
         size: int,
         cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
         device: str,
-        enable_memory_saver: bool,
         speculative_num_draft_tokens: Optional[int] = None,
     ):
         conv_state_shape = cache_params.shape.conv
         temporal_state_shape = cache_params.shape.temporal
         conv_dtype = cache_params.dtype.conv
         ssm_dtype = cache_params.dtype.temporal
-        self.size = size
-        self.device = device
-        self.free_slots = torch.arange(
-            self.size, dtype=torch.int64, device=self.device
-        )
-        self.memory_saver_adapter = TorchMemorySaverAdapter.create(enable=enable_memory_saver)
         num_mamba_layers = len(cache_params.layers)
         # for disagg with nvlink
@@ -181,8 +174,9 @@ class MambaPool:
             self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
         else:
             self.custom_mem_pool = None
         self.is_kda_cache = isinstance(cache_params, KimiLinearCacheParams)
-        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE), (
+        with (
             torch.cuda.use_mem_pool(self.custom_mem_pool)
             if self.enable_custom_mem_pool
             else nullcontext()
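The hunk above is the core of the revert: the Mamba cache tensors are no longer allocated inside a memory-saver region, only under the optional custom CUDA memory pool. Below is a minimal, self-contained sketch of the two allocation patterns, not sglang code: SaverAdapter is a hypothetical stand-in for sglang's TorchMemorySaverAdapter (a wrapper over the torch_memory_saver library), and the tensors and pool context are CPU placeholders so the snippet runs without a GPU.

from contextlib import contextmanager, nullcontext

import torch

GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"  # region tag, as in the diff above


class SaverAdapter:
    """Hypothetical stand-in for TorchMemorySaverAdapter."""

    @contextmanager
    def region(self, tag: str):
        # The real adapter tags allocations made in this scope so the
        # server can later release and re-create them on demand.
        print(f"enter saver region: {tag}")
        try:
            yield
        finally:
            print(f"exit saver region: {tag}")


adapter = SaverAdapter()
# Stand-in for torch.cuda.use_mem_pool(self.custom_mem_pool); nullcontext
# keeps the sketch runnable without CUDA.
pool_ctx = nullcontext()

# Pattern before the revert: the saver region wraps the allocation.
with adapter.region(GPU_MEMORY_TYPE_KV_CACHE), pool_ctx:
    conv_state = torch.zeros(4, 8)  # placeholder for the real cache tensor

# Pattern after the revert: only the (optional) mem-pool context remains.
with pool_ctx:
    temporal_state = torch.zeros(4, 8)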
@@ -276,6 +270,11 @@ class MambaPool:
             f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f} GB, "
             f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f} GB "
         )
+        self.size = size
+        self.device = device
+        self.free_slots = torch.arange(
+            self.size, dtype=torch.int64, device=self.device
+        )
         self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
         self.num_mamba_layers = num_mamba_layers
@@ -370,7 +369,6 @@ class HybridReqToTokenPool(ReqToTokenPool):
             device=device,
             enable_memory_saver=enable_memory_saver,
         )
-        self.enable_memory_saver = enable_memory_saver
         self._init_mamba_pool(
             size=mamba_size,
             cache_params=cache_params,
@@ -389,7 +387,6 @@ class HybridReqToTokenPool(ReqToTokenPool):
             size=size,
             cache_params=cache_params,
             device=device,
-            enable_memory_saver=self.enable_memory_saver,
             speculative_num_draft_tokens=speculative_num_draft_tokens,
         )
         self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
@@ -870,7 +867,6 @@ class HybridLinearKVPool(KVCache):
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
         device: str,
-        enable_memory_saver: bool,
         mamba_pool: MambaPool,
         # TODO: refactor mla related args
         use_mla: bool = False,
@@ -903,7 +899,7 @@ class HybridLinearKVPool(KVCache):
                 head_dim=head_dim,
                 layer_num=self.full_layer_nums,
                 device=device,
-                enable_memory_saver=enable_memory_saver,
+                enable_memory_saver=False,
             )
         else:
             TokenToKVPoolClass = MLATokenToKVPool
python/sglang/srt/model_executor/model_runner.py
@@ -1798,7 +1798,6 @@ class ModelRunner:
             ),
             enable_kvcache_transpose=False,
             device=self.device,
-            enable_memory_saver=self.server_args.enable_memory_saver,
             mamba_pool=self.req_to_token_pool.mamba_pool,
             use_mla=self.use_mla_backend,
             **extra_args,
test/srt/test_mamba_unittest.py
@@ -42,7 +42,6 @@ class TestMamba(unittest.TestCase):
             full_attention_layer_ids=full_attention_layer_ids,
             enable_kvcache_transpose=False,
             device=device,
-            enable_memory_saver=False,
             mamba_pool=None,
         )
         assert pool._transfer_full_attention_id(global_interval - 1) == 0
@@ -175,7 +174,6 @@ class TestMamba(unittest.TestCase):
             full_attention_layer_ids=full_attention_layer_ids,
             enable_kvcache_transpose=False,
             device=device,
-            enable_memory_saver=False,
             mamba_pool=req_to_token_pool.mamba_pool,
         )