"tutorials/basics/1_first.py" did not exist on "cab1fdf2ec8bb5b281db804dc8f5d282b653d5f8"
Unverified Commit f2887498 authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

Simplify memory pool (#9033)

parent 8ecf6b9d
...@@ -1251,30 +1251,33 @@ class ModelRunner: ...@@ -1251,30 +1251,33 @@ class ModelRunner:
# Draft worker shares req_to_token_pool with the target worker. # Draft worker shares req_to_token_pool with the target worker.
assert self.is_draft_worker assert self.is_draft_worker
if self.server_args.attention_backend == "ascend" and not self.use_mla_backend: if self.server_args.attention_backend == "ascend":
self.token_to_kv_pool = AscendTokenToKVPool( if self.use_mla_backend:
self.max_total_num_tokens, self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
page_size=self.page_size, self.max_total_num_tokens,
dtype=self.kv_cache_dtype, page_size=self.page_size,
head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()), dtype=self.kv_cache_dtype,
head_dim=self.model_config.head_dim, kv_lora_rank=self.model_config.kv_lora_rank,
layer_num=self.model_config.num_hidden_layers, qk_rope_head_dim=self.model_config.qk_rope_head_dim,
device=self.device, layer_num=self.num_effective_layers,
enable_memory_saver=self.server_args.enable_memory_saver, device=self.device,
) enable_memory_saver=self.server_args.enable_memory_saver,
elif self.server_args.attention_backend == "ascend" and self.use_mla_backend: start_layer=self.start_layer,
self.token_to_kv_pool = AscendMLAPagedTokenToKVPool( end_layer=self.end_layer,
self.max_total_num_tokens, )
page_size=self.page_size, else:
dtype=self.kv_cache_dtype, self.token_to_kv_pool = AscendTokenToKVPool(
kv_lora_rank=self.model_config.kv_lora_rank, self.max_total_num_tokens,
qk_rope_head_dim=self.model_config.qk_rope_head_dim, page_size=self.page_size,
layer_num=self.num_effective_layers, dtype=self.kv_cache_dtype,
device=self.device, head_num=self.model_config.get_num_kv_heads(
enable_memory_saver=self.server_args.enable_memory_saver, get_attention_tp_size()
start_layer=self.start_layer, ),
end_layer=self.end_layer, head_dim=self.model_config.head_dim,
) layer_num=self.model_config.num_hidden_layers,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
elif self.use_mla_backend: elif self.use_mla_backend:
self.token_to_kv_pool = MLATokenToKVPool( self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens, self.max_total_num_tokens,
...@@ -1333,6 +1336,7 @@ class ModelRunner: ...@@ -1333,6 +1336,7 @@ class ModelRunner:
end_layer=self.end_layer, end_layer=self.end_layer,
) )
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
if self.token_to_kv_pool_allocator is None: if self.token_to_kv_pool_allocator is None:
if self.page_size == 1: if self.page_size == 1:
if self.is_hybrid: if self.is_hybrid:
...@@ -1342,8 +1346,7 @@ class ModelRunner: ...@@ -1342,8 +1346,7 @@ class ModelRunner:
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
device=self.device, device=self.device,
kvcache=self.token_to_kv_pool, kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode need_sort=need_sort,
in ("decode", "prefill"),
) )
else: else:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator( self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
...@@ -1351,29 +1354,26 @@ class ModelRunner: ...@@ -1351,29 +1354,26 @@ class ModelRunner:
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
device=self.device, device=self.device,
kvcache=self.token_to_kv_pool, kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode need_sort=need_sort,
in ("decode", "prefill"),
) )
else: else:
if _is_npu: if not _is_npu:
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator( self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
self.max_total_num_tokens, self.max_total_num_tokens,
page_size=self.page_size, page_size=self.page_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
device=self.device, device=self.device,
kvcache=self.token_to_kv_pool, kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode need_sort=need_sort,
in ("decode", "prefill"),
) )
else: else:
self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator( self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens, self.max_total_num_tokens,
page_size=self.page_size, page_size=self.page_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
device=self.device, device=self.device,
kvcache=self.token_to_kv_pool, kvcache=self.token_to_kv_pool,
need_sort=self.server_args.disaggregation_mode need_sort=need_sort,
in ("decode", "prefill"),
) )
else: else:
assert self.is_draft_worker assert self.is_draft_worker
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment