Unverified Commit f2887498 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Simplify memory pool (#9033)

parent 8ecf6b9d
...@@ -1251,18 +1251,8 @@ class ModelRunner: ...@@ -1251,18 +1251,8 @@ class ModelRunner:
# Draft worker shares req_to_token_pool with the target worker. # Draft worker shares req_to_token_pool with the target worker.
assert self.is_draft_worker assert self.is_draft_worker
if self.server_args.attention_backend == "ascend" and not self.use_mla_backend: if self.server_args.attention_backend == "ascend":
self.token_to_kv_pool = AscendTokenToKVPool( if self.use_mla_backend:
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
elif self.server_args.attention_backend == "ascend" and self.use_mla_backend:
self.token_to_kv_pool = AscendMLAPagedTokenToKVPool( self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
self.max_total_num_tokens, self.max_total_num_tokens,
page_size=self.page_size, page_size=self.page_size,
...@@ -1275,6 +1265,19 @@ class ModelRunner: ...@@ -1275,6 +1265,19 @@ class ModelRunner:
start_layer=self.start_layer, start_layer=self.start_layer,
end_layer=self.end_layer, end_layer=self.end_layer,
) )
else:
self.token_to_kv_pool = AscendTokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(
get_attention_tp_size()
),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
elif self.use_mla_backend: elif self.use_mla_backend:
self.token_to_kv_pool = MLATokenToKVPool( self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens, self.max_total_num_tokens,
...@@ -1333,6 +1336,7 @@ class ModelRunner: ...@@ -1333,6 +1336,7 @@ class ModelRunner:
end_layer=self.end_layer, end_layer=self.end_layer,
) )
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
if self.token_to_kv_pool_allocator is None: if self.token_to_kv_pool_allocator is None:
if self.page_size == 1: if self.page_size == 1:
if self.is_hybrid: if self.is_hybrid:
...@@ -1342,8 +1346,7 @@ class ModelRunner: ...@@ -1342,8 +1346,7 @@ class ModelRunner:
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
device=self.device, device=self.device,
kvcache=self.token_to_kv_pool, kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode need_sort=need_sort,
in ("decode", "prefill"),
) )
else: else:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator( self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
...@@ -1351,29 +1354,26 @@ class ModelRunner: ...@@ -1351,29 +1354,26 @@ class ModelRunner:
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
device=self.device, device=self.device,
kvcache=self.token_to_kv_pool, kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode need_sort=need_sort,
in ("decode", "prefill"),
) )
else: else:
if _is_npu: if not _is_npu:
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator( self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
self.max_total_num_tokens, self.max_total_num_tokens,
page_size=self.page_size, page_size=self.page_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
device=self.device, device=self.device,
kvcache=self.token_to_kv_pool, kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode need_sort=need_sort,
in ("decode", "prefill"),
) )
else: else:
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator( self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens, self.max_total_num_tokens,
page_size=self.page_size, page_size=self.page_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
device=self.device, device=self.device,
kvcache=self.token_to_kv_pool, kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode need_sort=need_sort,
in ("decode", "prefill"),
) )
else: else:
assert self.is_draft_worker assert self.is_draft_worker
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment