Unverified commit 173e0f70, authored by Junrong Lin, committed by GitHub

Enable memory saver for hybrid model (#11974)

parent f600866a
@@ -154,12 +154,19 @@ class MambaPool:
         size: int,
         cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
         device: str,
+        enable_memory_saver: bool,
         speculative_num_draft_tokens: Optional[int] = None,
     ):
         conv_state_shape = cache_params.shape.conv
         temporal_state_shape = cache_params.shape.temporal
         conv_dtype = cache_params.dtype.conv
         ssm_dtype = cache_params.dtype.temporal
+        self.size = size
+        self.device = device
+        self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
         num_mamba_layers = len(cache_params.layers)
         # for disagg with nvlink
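
The constructor now builds a `TorchMemorySaverAdapter` up front. The pattern this relies on (a minimal sketch of the assumed contract, not sglang's actual implementation) is that `create(enable=False)` returns a no-op adapter, so allocation sites can always wrap themselves in `region(...)` without branching on the flag:

```python
from contextlib import contextmanager

class _NoopSaverAdapter:
    """Stand-in returned when the memory saver is disabled."""
    @contextmanager
    def region(self, tag: str):
        yield  # no tagging; allocations behave as usual

class _SaverAdapter:
    """Stand-in for the enabled path; a real adapter would route
    allocations made inside `region` through torch_memory_saver so the
    tagged memory can later be released and resumed."""
    @contextmanager
    def region(self, tag: str):
        yield

def create_adapter(enable: bool):
    # Mirrors the assumed TorchMemorySaverAdapter.create(enable=...) contract.
    return _SaverAdapter() if enable else _NoopSaverAdapter()
```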
@@ -174,9 +181,8 @@ class MambaPool:
             self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
         else:
             self.custom_mem_pool = None
         self.is_kda_cache = isinstance(cache_params, KimiLinearCacheParams)
-        with (
+        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE), (
             torch.cuda.use_mem_pool(self.custom_mem_pool)
             if self.enable_custom_mem_pool
             else nullcontext()
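
The cache tensors are now allocated under two stacked context managers in a single `with` statement: the memory-saver region tags the allocations, and `torch.cuda.use_mem_pool` conditionally routes them into the custom pool, falling back to `nullcontext()`. A minimal illustration of the same pattern, with a hypothetical `"kv_cache"` tag standing in for `GPU_MEMORY_TYPE_KV_CACHE`:

```python
from contextlib import nullcontext
import torch

def allocate_states(adapter, custom_mem_pool, shape, dtype, device):
    # Stacked contexts, as in the diff: tag for the memory saver first,
    # then optionally redirect the allocation into a custom CUDA pool.
    with adapter.region("kv_cache"), (
        torch.cuda.use_mem_pool(custom_mem_pool)
        if custom_mem_pool is not None
        else nullcontext()
    ):
        return torch.zeros(shape, dtype=dtype, device=device)
```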
@@ -270,11 +276,6 @@ class MambaPool:
             f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
             f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
         )
-        self.size = size
-        self.device = device
-        self.free_slots = torch.arange(
-            self.size, dtype=torch.int64, device=self.device
-        )
         self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
         self.num_mamba_layers = num_mamba_layers
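
The `size`/`device`/`free_slots` bookkeeping moves from the end of `__init__` to before the state allocation, presumably so that only the cache tensors themselves land inside the memory-saver region. `free_slots` is an int64 index tensor acting as the pool's free list; a minimal sketch of the assumed alloc/free semantics (not sglang code):

```python
import torch

class SlotPool:
    """Minimal sketch of the assumed free-slot semantics."""
    def __init__(self, size: int, device: str = "cpu"):
        self.free_slots = torch.arange(size, dtype=torch.int64, device=device)

    def alloc(self, need: int) -> torch.Tensor:
        assert need <= self.free_slots.numel(), "Mamba cache pool exhausted"
        taken, self.free_slots = self.free_slots[:need], self.free_slots[need:]
        return taken

    def free(self, slots: torch.Tensor) -> None:
        self.free_slots = torch.cat((self.free_slots, slots))
```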
@@ -369,6 +370,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
             device=device,
             enable_memory_saver=enable_memory_saver,
         )
+        self.enable_memory_saver = enable_memory_saver
         self._init_mamba_pool(
             size=mamba_size,
             cache_params=cache_params,
@@ -387,6 +389,7 @@ class HybridReqToTokenPool(ReqToTokenPool):
             size=size,
             cache_params=cache_params,
             device=device,
+            enable_memory_saver=self.enable_memory_saver,
             speculative_num_draft_tokens=speculative_num_draft_tokens,
         )
         self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
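
`mamba_map` translates global layer ids into the Mamba pool's dense indices: hybrid models interleave full-attention and linear layers, so the Mamba layers are a sparse subset of the layer range. With made-up layer ids:

```python
mamba_layer_ids = [1, 3, 5, 7]  # hypothetical Mamba layers in a hybrid stack
mamba_map = {layer_id: i for i, layer_id in enumerate(mamba_layer_ids)}
assert mamba_map == {1: 0, 3: 1, 5: 2, 7: 3}  # global id -> pool slot
```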
@@ -867,6 +870,7 @@ class HybridLinearKVPool(KVCache):
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
         device: str,
+        enable_memory_saver: bool,
         mamba_pool: MambaPool,
         # TODO: refactor mla related args
         use_mla: bool = False,
@@ -899,7 +903,7 @@ class HybridLinearKVPool(KVCache):
                 head_dim=head_dim,
                 layer_num=self.full_layer_nums,
                 device=device,
-                enable_memory_saver=False,
+                enable_memory_saver=enable_memory_saver,
             )
         else:
             TokenToKVPoolClass = MLATokenToKVPool
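
This is the substantive fix: the full-attention sub-pool previously hardcoded `enable_memory_saver=False`, so a hybrid model's KV cache never registered with the memory saver even when the server flag was set. The flag is now forwarded through the wrapper. A schematic of the before/after wiring (class names are illustrative, not the real ones):

```python
class FullAttentionKVPool:
    def __init__(self, enable_memory_saver: bool):
        self.enable_memory_saver = enable_memory_saver

class HybridWrapper:
    def __init__(self, enable_memory_saver: bool):
        # Before this commit the flag was dropped here:
        #   self.full_kv_pool = FullAttentionKVPool(enable_memory_saver=False)
        # Now it propagates, so the sub-pool's memory is releasable too:
        self.full_kv_pool = FullAttentionKVPool(enable_memory_saver)
```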
@@ -1798,6 +1798,7 @@ class ModelRunner:
                 ),
                 enable_kvcache_transpose=False,
                 device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
                 mamba_pool=self.req_to_token_pool.mamba_pool,
                 use_mla=self.use_mla_backend,
                 **extra_args,
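
At the top of the chain, `ModelRunner` reads the flag from `ServerArgs`, i.e. the `--enable-memory-saver` launch option. The point of wiring it through is that hybrid models can now take part in the release/resume flow used, for example, when colocating training and inference. A hedged usage sketch against sglang's offline Engine API (the model path is a placeholder):

```python
import sglang as sgl

# Hedged sketch: assumes the Engine exposes the release/resume hooks
# when launched with memory saving enabled, as in recent sglang.
engine = sgl.Engine(model_path="some/hybrid-model", enable_memory_saver=True)
engine.release_memory_occupation()  # free KV/Mamba cache GPU memory
# ... use the GPU for something else ...
engine.resume_memory_occupation()   # re-materialize the caches
```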
@@ -42,6 +42,7 @@ class TestMamba(unittest.TestCase):
             full_attention_layer_ids=full_attention_layer_ids,
             enable_kvcache_transpose=False,
             device=device,
+            enable_memory_saver=False,
             mamba_pool=None,
         )
         assert pool._transfer_full_attention_id(global_interval - 1) == 0
@@ -174,6 +175,7 @@ class TestMamba(unittest.TestCase):
             full_attention_layer_ids=full_attention_layer_ids,
             enable_kvcache_transpose=False,
             device=device,
+            enable_memory_saver=False,
             mamba_pool=req_to_token_pool.mamba_pool,
         )