Unverified Commit b4c2c421 authored by huangtingwei's avatar huangtingwei Committed by GitHub
Browse files

support memory_pool_host page first direct layout (#10031)


Co-authored-by: default avatarZhiqiang Xie <xiezhq@stanford.edu>
parent 53ca1552
...@@ -532,9 +532,12 @@ class HiCacheController: ...@@ -532,9 +532,12 @@ class HiCacheController:
host_indices = host_indices.to(self.device, non_blocking=True) host_indices = host_indices.to(self.device, non_blocking=True)
return host_indices, device_indices return host_indices, device_indices
elif self.io_backend == "direct": elif self.io_backend == "direct":
device_indices = device_indices.cpu() if self.mem_pool_host.layout == "layer_first":
host_indices, idx = host_indices.sort() device_indices = device_indices.cpu()
return host_indices, device_indices.index_select(0, idx) host_indices, idx = host_indices.sort()
return host_indices, device_indices.index_select(0, idx)
elif self.mem_pool_host.layout == "page_first_direct":
return host_indices, device_indices.cpu()
else: else:
raise ValueError(f"Unsupported io backend") raise ValueError(f"Unsupported io backend")
......
...@@ -16,11 +16,13 @@ _is_xpu = is_xpu() ...@@ -16,11 +16,13 @@ _is_xpu = is_xpu()
if not (_is_npu or _is_xpu): if not (_is_npu or _is_xpu):
from sgl_kernel.kvcacheio import ( from sgl_kernel.kvcacheio import (
transfer_kv_all_layer, transfer_kv_all_layer,
transfer_kv_all_layer_direct_lf_pf,
transfer_kv_all_layer_lf_pf, transfer_kv_all_layer_lf_pf,
transfer_kv_all_layer_mla, transfer_kv_all_layer_mla,
transfer_kv_all_layer_mla_lf_pf, transfer_kv_all_layer_mla_lf_pf,
transfer_kv_direct, transfer_kv_direct,
transfer_kv_per_layer, transfer_kv_per_layer,
transfer_kv_per_layer_direct_pf_lf,
transfer_kv_per_layer_mla, transfer_kv_per_layer_mla,
transfer_kv_per_layer_mla_pf_lf, transfer_kv_per_layer_mla_pf_lf,
transfer_kv_per_layer_pf_lf, transfer_kv_per_layer_pf_lf,
...@@ -78,6 +80,7 @@ class HostKVCache(abc.ABC): ...@@ -78,6 +80,7 @@ class HostKVCache(abc.ABC):
self.size = int(device_pool.size * host_to_device_ratio) self.size = int(device_pool.size * host_to_device_ratio)
# Align the host memory pool size to the page size # Align the host memory pool size to the page size
self.size = self.size - (self.size % self.page_size) self.size = self.size - (self.size % self.page_size)
self.page_num = self.size // self.page_size
self.start_layer = device_pool.start_layer self.start_layer = device_pool.start_layer
self.end_layer = device_pool.end_layer self.end_layer = device_pool.end_layer
...@@ -317,6 +320,15 @@ class MHATokenToKVPoolHost(HostKVCache): ...@@ -317,6 +320,15 @@ class MHATokenToKVPoolHost(HostKVCache):
dims = (2, self.layer_num, self.size, self.head_num, self.head_dim) dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
elif self.layout == "page_first": elif self.layout == "page_first":
dims = (2, self.size, self.layer_num, self.head_num, self.head_dim) dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
elif self.layout == "page_first_direct":
dims = (
2,
self.page_num,
self.layer_num,
self.page_size,
self.head_num,
self.head_dim,
)
else: else:
raise ValueError(f"Unsupported layout: {self.layout}") raise ValueError(f"Unsupported layout: {self.layout}")
self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
...@@ -370,19 +382,31 @@ class MHATokenToKVPoolHost(HostKVCache): ...@@ -370,19 +382,31 @@ class MHATokenToKVPoolHost(HostKVCache):
else: else:
raise ValueError(f"Unsupported layout: {self.layout}") raise ValueError(f"Unsupported layout: {self.layout}")
elif io_backend == "direct": elif io_backend == "direct":
assert ( if self.layout == "layer_first":
self.layout == "layer_first" transfer_kv_direct(
), f"Direct IO backend only supports layer_first layout." src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
transfer_kv_direct( dst_layers=[
src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]], device_pool.k_buffer[layer_id],
dst_layers=[ device_pool.v_buffer[layer_id],
device_pool.k_buffer[layer_id], ],
device_pool.v_buffer[layer_id], src_indices=host_indices,
], dst_indices=device_indices,
src_indices=host_indices, page_size=self.page_size,
dst_indices=device_indices, )
page_size=self.page_size, elif self.layout == "page_first_direct":
) transfer_kv_per_layer_direct_pf_lf(
src_ptrs=[self.k_buffer, self.v_buffer],
dst_ptrs=[
device_pool.k_buffer[layer_id],
device_pool.v_buffer[layer_id],
],
src_indices=host_indices,
dst_indices=device_indices,
layer_id=layer_id,
page_size=self.page_size,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
else: else:
raise ValueError(f"Unsupported IO backend: {io_backend}") raise ValueError(f"Unsupported IO backend: {io_backend}")
...@@ -416,16 +440,24 @@ class MHATokenToKVPoolHost(HostKVCache): ...@@ -416,16 +440,24 @@ class MHATokenToKVPoolHost(HostKVCache):
else: else:
raise ValueError(f"Unsupported layout: {self.layout}") raise ValueError(f"Unsupported layout: {self.layout}")
elif io_backend == "direct": elif io_backend == "direct":
assert ( if self.layout == "layer_first":
self.layout == "layer_first" transfer_kv_direct(
), f"Direct IO backend only supports layer_first layout." src_layers=device_pool.k_buffer + device_pool.v_buffer,
transfer_kv_direct( dst_layers=self.k_data_refs + self.v_data_refs,
src_layers=device_pool.k_buffer + device_pool.v_buffer, src_indices=device_indices,
dst_layers=self.k_data_refs + self.v_data_refs, dst_indices=host_indices,
src_indices=device_indices, page_size=self.page_size,
dst_indices=host_indices, )
page_size=self.page_size, elif self.layout == "page_first_direct":
) transfer_kv_all_layer_direct_lf_pf(
src_ptrs=device_pool.k_buffer + device_pool.v_buffer,
dst_ptrs=[self.k_buffer, self.v_buffer],
src_indices=device_indices,
dst_indices=host_indices,
page_size=self.page_size,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
else: else:
raise ValueError(f"Unsupported IO backend: {io_backend}") raise ValueError(f"Unsupported IO backend: {io_backend}")
...@@ -580,6 +612,14 @@ class MLATokenToKVPoolHost(HostKVCache): ...@@ -580,6 +612,14 @@ class MLATokenToKVPoolHost(HostKVCache):
1, 1,
self.kv_lora_rank + self.qk_rope_head_dim, self.kv_lora_rank + self.qk_rope_head_dim,
) )
elif self.layout == "page_first_direct":
dims = (
self.page_num,
self.layer_num,
self.page_size,
1,
self.kv_lora_rank + self.qk_rope_head_dim,
)
else: else:
raise ValueError(f"Unsupported layout: {self.layout}") raise ValueError(f"Unsupported layout: {self.layout}")
self.token_stride_size = ( self.token_stride_size = (
...@@ -619,16 +659,25 @@ class MLATokenToKVPoolHost(HostKVCache): ...@@ -619,16 +659,25 @@ class MLATokenToKVPoolHost(HostKVCache):
else: else:
raise ValueError(f"Unsupported layout: {self.layout}") raise ValueError(f"Unsupported layout: {self.layout}")
elif io_backend == "direct": elif io_backend == "direct":
assert ( if self.layout == "layer_first":
self.layout == "layer_first" transfer_kv_direct(
), f"Direct IO backend only supports layer_first layout." src_layers=[self.kv_buffer[layer_id]],
transfer_kv_direct( dst_layers=[device_pool.kv_buffer[layer_id]],
src_layers=[self.kv_buffer[layer_id]], src_indices=host_indices,
dst_layers=[device_pool.kv_buffer[layer_id]], dst_indices=device_indices,
src_indices=host_indices, page_size=self.page_size,
dst_indices=device_indices, )
page_size=self.page_size, elif self.layout == "page_first_direct":
) transfer_kv_per_layer_direct_pf_lf(
src_ptrs=[self.kv_buffer],
dst_ptrs=[device_pool.kv_buffer[layer_id]],
src_indices=host_indices,
dst_indices=device_indices,
layer_id=layer_id,
page_size=self.page_size,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
def backup_from_device_all_layer( def backup_from_device_all_layer(
self, device_pool, host_indices, device_indices, io_backend self, device_pool, host_indices, device_indices, io_backend
...@@ -656,16 +705,24 @@ class MLATokenToKVPoolHost(HostKVCache): ...@@ -656,16 +705,24 @@ class MLATokenToKVPoolHost(HostKVCache):
else: else:
raise ValueError(f"Unsupported layout: {self.layout}") raise ValueError(f"Unsupported layout: {self.layout}")
elif io_backend == "direct": elif io_backend == "direct":
assert ( if self.layout == "layer_first":
self.layout == "layer_first" transfer_kv_direct(
), f"Direct IO backend only supports layer_first layout." src_layers=device_pool.kv_buffer,
transfer_kv_direct( dst_layers=self.data_refs,
src_layers=device_pool.kv_buffer, src_indices=device_indices,
dst_layers=self.data_refs, dst_indices=host_indices,
src_indices=device_indices, page_size=self.page_size,
dst_indices=host_indices, )
page_size=self.page_size, elif self.layout == "page_first_direct":
) transfer_kv_all_layer_direct_lf_pf(
src_ptrs=device_pool.kv_buffer,
dst_ptrs=[self.kv_buffer],
src_indices=device_indices,
dst_indices=host_indices,
page_size=self.page_size,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
else: else:
raise ValueError(f"Unsupported IO backend: {io_backend}") raise ValueError(f"Unsupported IO backend: {io_backend}")
......
...@@ -721,6 +721,13 @@ class ServerArgs: ...@@ -721,6 +721,13 @@ class ServerArgs:
self.hicache_io_backend = "kernel" self.hicache_io_backend = "kernel"
self.hicache_mem_layout = "page_first" self.hicache_mem_layout = "page_first"
if self.hicache_mem_layout == "page_first_direct":
if self.hicache_io_backend != "direct":
self.hicache_io_backend = "direct"
logger.warning(
"Page first direct layout only support direct io backend"
)
# Speculative Decoding # Speculative Decoding
if self.speculative_algorithm == "NEXTN": if self.speculative_algorithm == "NEXTN":
# NEXTN shares the same implementation of EAGLE # NEXTN shares the same implementation of EAGLE
...@@ -1779,7 +1786,7 @@ class ServerArgs: ...@@ -1779,7 +1786,7 @@ class ServerArgs:
parser.add_argument( parser.add_argument(
"--hicache-mem-layout", "--hicache-mem-layout",
type=str, type=str,
choices=["layer_first", "page_first"], choices=["layer_first", "page_first", "page_first_direct"],
default=ServerArgs.hicache_mem_layout, default=ServerArgs.hicache_mem_layout,
help="The layout of host memory pool for hierarchical cache.", help="The layout of host memory pool for hierarchical cache.",
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment