"tools/vscode:/vscode.git/clone" did not exist on "e682fa74bad92449e86cc01e368d504b7fd5665e"
Unverified commit d22d0447, authored by Baizhou Zhang, committed by GitHub

Revert "Enable memory saver for hybrid model" (#12648)

parent 887742a1
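For context, the feature being reverted wrapped the Mamba cache allocation in a memory-saver region so its GPU tensors could be released and restored on demand. The sketch below is a hypothetical stand-in for that pattern, built only from the two calls visible in the diff (create(enable=...) and a region(...) context manager); the real TorchMemorySaverAdapter in SGLang actually pauses and resumes CUDA allocations, which this runnable no-op version does not.

    from contextlib import contextmanager

    class NoopMemorySaverAdapter:
        """Stand-in for TorchMemorySaverAdapter; the no-op mirrors enable=False."""

        @staticmethod
        def create(enable: bool):
            return NoopMemorySaverAdapter()

        @contextmanager
        def region(self, tag: str):
            # Tensors allocated inside this block would be registered under
            # `tag` so they can later be released and restored; here it is
            # a no-op, so the block simply runs.
            yield

    adapter = NoopMemorySaverAdapter.create(enable=False)
    with adapter.region("kv_cache"):
        cache = [0.0] * 1024  # allocate cache buffers inside the region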
@@ -154,19 +154,12 @@ class MambaPool:
         size: int,
         cache_params: Union["Mamba2CacheParams", "KimiLinearCacheParams"],
         device: str,
-        enable_memory_saver: bool,
         speculative_num_draft_tokens: Optional[int] = None,
     ):
         conv_state_shape = cache_params.shape.conv
         temporal_state_shape = cache_params.shape.temporal
         conv_dtype = cache_params.dtype.conv
         ssm_dtype = cache_params.dtype.temporal
-        self.size = size
-        self.device = device
-        self.free_slots = torch.arange(self.size, dtype=torch.int64, device=self.device)
-        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
-        )
         num_mamba_layers = len(cache_params.layers)
         # for disagg with nvlink
@@ -181,8 +174,9 @@ class MambaPool:
             self.custom_mem_pool = torch.cuda.MemPool(allocator.allocator())
         else:
             self.custom_mem_pool = None
         self.is_kda_cache = isinstance(cache_params, KimiLinearCacheParams)
-        with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE), (
+        with (
             torch.cuda.use_mem_pool(self.custom_mem_pool)
             if self.enable_custom_mem_pool
             else nullcontext()
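The hunk above keeps the conditional context-manager idiom that survives the revert: enter the custom CUDA memory pool only when one was created, otherwise fall back to a no-op nullcontext. A minimal sketch of that idiom follows; allocate_in_pool is a hypothetical helper, and it gates on a None check where the real code gates on self.enable_custom_mem_pool.

    import torch
    from contextlib import nullcontext

    def allocate_in_pool(custom_mem_pool, shape, device="cuda"):
        # Route the allocation through the custom MemPool when one exists;
        # nullcontext() makes the "no pool" path a plain allocation.
        ctx = (
            torch.cuda.use_mem_pool(custom_mem_pool)
            if custom_mem_pool is not None
            else nullcontext()
        )
        with ctx:
            return torch.empty(shape, device=device)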
@@ -276,6 +270,11 @@ class MambaPool:
             f"conv_state size: {get_tensor_size_bytes(conv_state) / GB:.2f}GB, "
             f"ssm_state size: {get_tensor_size_bytes(temporal_state) / GB:.2f}GB "
         )
+        self.size = size
+        self.device = device
+        self.free_slots = torch.arange(
+            self.size, dtype=torch.int64, device=self.device
+        )
         self.mem_usage = self.mamba_cache.mem_usage_bytes() / GB
         self.num_mamba_layers = num_mamba_layers
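The re-added block above is the pool's free-list bookkeeping: a torch.arange tensor of slot indices from which allocations are sliced, moved back below the cache allocation it indexes. A minimal sketch of that scheme, with hypothetical alloc/free helpers that the real MambaPool API does not necessarily match:

    import torch

    class FreeSlotList:
        def __init__(self, size: int, device: str = "cpu"):
            # Every slot index starts free, as in the hunk above.
            self.free_slots = torch.arange(size, dtype=torch.int64, device=device)

        def alloc(self, need: int) -> torch.Tensor:
            # Hand out the first `need` free indices and keep the rest.
            assert need <= self.free_slots.numel(), "cache pool exhausted"
            slots, self.free_slots = self.free_slots[:need], self.free_slots[need:]
            return slots

        def free(self, slots: torch.Tensor) -> None:
            # Returned slots simply rejoin the free list.
            self.free_slots = torch.cat((self.free_slots, slots))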
@@ -370,7 +369,6 @@ class HybridReqToTokenPool(ReqToTokenPool):
             device=device,
             enable_memory_saver=enable_memory_saver,
         )
-        self.enable_memory_saver = enable_memory_saver
         self._init_mamba_pool(
             size=mamba_size,
             cache_params=cache_params,
@@ -389,7 +387,6 @@ class HybridReqToTokenPool(ReqToTokenPool):
             size=size,
             cache_params=cache_params,
             device=device,
-            enable_memory_saver=self.enable_memory_saver,
             speculative_num_draft_tokens=speculative_num_draft_tokens,
         )
         self.mamba_map = {layer_id: i for i, layer_id in enumerate(cache_params.layers)}
@@ -870,7 +867,6 @@ class HybridLinearKVPool(KVCache):
         full_attention_layer_ids: List[int],
         enable_kvcache_transpose: bool,
         device: str,
-        enable_memory_saver: bool,
         mamba_pool: MambaPool,
         # TODO: refactor mla related args
         use_mla: bool = False,
@@ -903,7 +899,7 @@ class HybridLinearKVPool(KVCache):
                 head_dim=head_dim,
                 layer_num=self.full_layer_nums,
                 device=device,
-                enable_memory_saver=enable_memory_saver,
+                enable_memory_saver=False,
             )
         else:
             TokenToKVPoolClass = MLATokenToKVPool
@@ -1798,7 +1798,6 @@ class ModelRunner:
                 ),
                 enable_kvcache_transpose=False,
                 device=self.device,
-                enable_memory_saver=self.server_args.enable_memory_saver,
                 mamba_pool=self.req_to_token_pool.mamba_pool,
                 use_mla=self.use_mla_backend,
                 **extra_args,
@@ -42,7 +42,6 @@ class TestMamba(unittest.TestCase):
             full_attention_layer_ids=full_attention_layer_ids,
             enable_kvcache_transpose=False,
             device=device,
-            enable_memory_saver=False,
             mamba_pool=None,
         )
         assert pool._transfer_full_attention_id(global_interval - 1) == 0
@@ -175,7 +174,6 @@ class TestMamba(unittest.TestCase):
             full_attention_layer_ids=full_attention_layer_ids,
             enable_kvcache_transpose=False,
             device=device,
-            enable_memory_saver=False,
             mamba_pool=req_to_token_pool.mamba_pool,
         )