Unverified Commit 173e0f70, authored by Junrong Lin, committed by GitHub

Enable memory saver for hybrid model (#11974)

parent f600866a
@@ -154,12 +154,19 @@ class MambaPool:
         size: int,
         cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
         device: str,
+        enable_memory_saver: bool,
         speculative_num_draft_tokens: Optional[int] = None,
     ):
         conv_state_shape = cache_params.shape.conv
         temporal_state_shape = cache_params.shape.temporal
         conv_dtype = cache_params.dtype.conv
         ssm_dtype = cache_params.dtype.temporal
+        self.size = size
+        self.device = device
+        self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
         num_mamba_layers = len(cache_params.layers)
         # for disagg with nvlink
@@ -174,9 +181,8 @@ class MambaPool:
             self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
         else:
             self.custom_mem_pool = None
         self.is_kda_cache = isinstance(cache_params, KimiLinearCacheParams)
-        with (
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE), (
             torch.cuda.use_mem_pool(self.custom_mem_pool)
             if self.enable_custom_mem_pool
             else nullcontext()
@@ -270,11 +276,6 @@ class MambaPool:
             f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
             f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
         )
-        self.size = size
-        self.device = device
-        self.free_slots = torch.arange(
-            self.size, dtype=torch.int64, device=self.device
-        )
         self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
         self.num_mamba_layers = num_mamba_layers
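
Taken together, the three hunks above move the slot bookkeeping (`size`, `device`, `free_slots`) ahead of the cache allocation and give `MambaPool` its own `TorchMemorySaverAdapter`, so the free-slot tensor stays resident while the conv/SSM state tensors become pausable. A minimal sketch of that pattern follows; `PoolSketch`, the adapter's import path, and the tag value are assumptions, while the `create()`/`region()` usage mirrors the diff:

```python
# Minimal sketch, not the actual MambaPool: the class name, import path,
# and tag value are assumptions; create()/region() usage follows the diff.
import torch

from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"  # assumed value of the tag


class PoolSketch:
    def __init__(self, size: int, device: str, enable_memory_saver: bool):
        # Bookkeeping is created *before* entering the saver region, so it
        # survives release/resume cycles of the cache memory.
        self.free_slots = torch.arange(size, dtype=torch.int64, device=device)
        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
            enable=enable_memory_saver
        )
        # Tensors allocated inside the region can later be released to the
        # OS and re-materialized without changing their virtual addresses.
        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
            self.conv_state = torch.zeros(size, 512, device=device)
```
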
@@ -369,6 +370,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
             device=device,
             enable_memory_saver=enable_memory_saver,
         )
+        self.enable_memory_saver = enable_memory_saver
         self._init_mamba_pool(
             size=mamba_size,
             cache_params=cache_params,
@@ -387,6 +389,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
             size=size,
             cache_params=cache_params,
             device=device,
+            enable_memory_saver=self.enable_memory_saver,
             speculative_num_draft_tokens=speculative_num_draft_tokens,
         )
         self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
@@ -867,6 +870,7 @@ class HybridLinearKVPool(KVCache):
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
         device: str,
+        enable_memory_saver: bool,
         mamba_pool: MambaPool,
         # TODO: refactor mla related args
         use_mla: bool = False,
@@ -899,7 +903,7 @@ class HybridLinearKVPool(KVCache):
                 head_dim=head_dim,
                 layer_num=self.full_layer_nums,
                 device=device,
-                enable_memory_saver=False,
+                enable_memory_saver=enable_memory_saver,
             )
         else:
             TokenToKVPoolClass = MLATokenToKVPool
...
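
The rewritten `with` statement stacks the saver region with the optional custom CUDA mem pool (the one created "for disagg with nvlink"), and the last hunk stops hardcoding `enable_memory_saver=False` for the full-attention KV pool inside `HybridLinearKVPool`, threading the caller's flag through instead. A standalone illustration of the stacked context managers; `allocate_state` is a hypothetical helper, the stacking itself mirrors the diff:

```python
# Illustration only: `allocate_state` is hypothetical; the context-manager
# stacking mirrors the diff above.
from contextlib import nullcontext

import torch

GPU_MEMORY_TYPE_KV_CACHE = "kv_cache"  # assumed tag, as in the sketch above


def allocate_state(adapter, custom_mem_pool, shape, device):
    # Enter the saver region first, then (optionally) route the allocation
    # through the custom mem pool; either way the tensor is saver-managed.
    with adapter.region(GPU_MEMORY_TYPE_KV_CACHE), (
        torch.cuda.use_mem_pool(custom_mem_pool)
        if custom_mem_pool is not None
        else nullcontext()
    ):
        return torch.zeros(shape, device=device)
```
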
@@ -1798,6 +1798,7 @@ class ModelRunner:
             ),
             enable_kvcache_transpose=False,
             device=self.device,
+            enable_memory_saver=self.server_args.enable_memory_saver,
             mamba_pool=self.req_to_token_pool.mamba_pool,
             use_mla=self.use_mla_backend,
             **extra_args,
...
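
At the `ModelRunner` level the flag is simply read from the server arguments, so hybrid (Mamba + attention) models now honor the same `--enable-memory-saver` switch as dense ones. A hedged usage sketch with sglang's offline engine: it assumes `sgl.Engine` forwards `ServerArgs` fields as keyword arguments and exposes the release/resume memory-occupation calls; the model path is a placeholder:

```python
# Usage sketch under the assumptions stated above; not taken from this commit.
import sglang as sgl

if __name__ == "__main__":
    engine = sgl.Engine(
        model_path="<path-to-hybrid-model>",  # placeholder
        enable_memory_saver=True,  # now reaches MambaPool / HybridLinearKVPool
    )
    # With the saver enabled, cache memory can be handed back to the OS
    # between rollouts and re-materialized at the same addresses.
    engine.release_memory_occupation()
    engine.resume_memory_occupation()
```
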
@@ -42,6 +42,7 @@ class TestMamba(unittest.TestCase):
             full_attention_layer_ids=full_attention_layer_ids,
             enable_kvcache_transpose=False,
             device=device,
+            enable_memory_saver=False,
             mamba_pool=None,
         )
         assert pool._transfer_full_attention_id(global_interval - 1) == 0
@@ -174,6 +175,7 @@ class TestMamba(unittest.TestCase):
             full_attention_layer_ids=full_attention_layer_ids,
             enable_kvcache_transpose=False,
             device=device,
+            enable_memory_saver=False,
             mamba_pool=req_to_token_pool.mamba_pool,
         )
...