Unverified Commit f2887498 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Simplify memory pool (#9033)

parent 8ecf6b9d
......@@ -1251,30 +1251,33 @@ class ModelRunner:
# Draft worker shares req_to_token_pool with the target worker.
assert self.is_draft_worker
if self.server_args.attention_backend == "ascend" and not self.use_mla_backend:
self.token_to_kv_pool = AscendTokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
elif self.server_args.attention_backend == "ascend" and self.use_mla_backend:
self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=self.num_effective_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
start_layer=self.start_layer,
end_layer=self.end_layer,
)
if self.server_args.attention_backend == "ascend":
if self.use_mla_backend:
self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
kv_lora_rank=self.model_config.kv_lora_rank,
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
layer_num=self.num_effective_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
start_layer=self.start_layer,
end_layer=self.end_layer,
)
else:
self.token_to_kv_pool = AscendTokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
head_num=self.model_config.get_num_kv_heads(
get_attention_tp_size()
),
head_dim=self.model_config.head_dim,
layer_num=self.model_config.num_hidden_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
elif self.use_mla_backend:
self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens,
......@@ -1333,6 +1336,7 @@ class ModelRunner:
end_layer=self.end_layer,
)
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
if self.token_to_kv_pool_allocator is None:
if self.page_size == 1:
if self.is_hybrid:
......@@ -1342,8 +1346,7 @@ class ModelRunner:
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode
in ("decode", "prefill"),
need_sort=need_sort,
)
else:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
......@@ -1351,29 +1354,26 @@ class ModelRunner:
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode
in ("decode", "prefill"),
need_sort=need_sort,
)
else:
if _is_npu:
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
if not _is_npu:
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode
in ("decode", "prefill"),
need_sort=need_sort,
)
else:
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
dtype=self.kv_cache_dtype,
device=self.device,
kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode
in ("decode", "prefill"),
need_sort=need_sort,
)
else:
assert self.is_draft_worker
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment