Unverified Commit b4c2c421 authored by huangtingwei's avatar huangtingwei Committed by GitHub
Browse files

support memory_pool_host page first direct layout (#10031)


Co-authored-by: default avatarZhiqiang Xie <xiezhq@stanford.edu>
parent 53ca1552
......@@ -532,9 +532,12 @@ class HiCacheController:
host_indices = host_indices.to(self.device, non_blocking=True)
return host_indices, device_indices
elif self.io_backend == "direct":
device_indices = device_indices.cpu()
host_indices, idx = host_indices.sort()
return host_indices, device_indices.index_select(0, idx)
if self.mem_pool_host.layout == "layer_first":
device_indices = device_indices.cpu()
host_indices, idx = host_indices.sort()
return host_indices, device_indices.index_select(0, idx)
elif self.mem_pool_host.layout == "page_first_direct":
return host_indices, device_indices.cpu()
else:
raise ValueError(f"Unsupported io backend")
......
......@@ -16,11 +16,13 @@ _is_xpu = is_xpu()
if not (_is_npu or _is_xpu):
from sgl_kernel.kvcacheio import (
transfer_kv_all_layer,
transfer_kv_all_layer_direct_lf_pf,
transfer_kv_all_layer_lf_pf,
transfer_kv_all_layer_mla,
transfer_kv_all_layer_mla_lf_pf,
transfer_kv_direct,
transfer_kv_per_layer,
transfer_kv_per_layer_direct_pf_lf,
transfer_kv_per_layer_mla,
transfer_kv_per_layer_mla_pf_lf,
transfer_kv_per_layer_pf_lf,
......@@ -78,6 +80,7 @@ class HostKVCache(abc.ABC):
self.size = int(device_pool.size * host_to_device_ratio)
# Align the host memory pool size to the page size
self.size = self.size - (self.size % self.page_size)
self.page_num = self.size // self.page_size
self.start_layer = device_pool.start_layer
self.end_layer = device_pool.end_layer
......@@ -317,6 +320,15 @@ class MHATokenToKVPoolHost(HostKVCache):
dims = (2, self.layer_num, self.size, self.head_num, self.head_dim)
elif self.layout == "page_first":
dims = (2, self.size, self.layer_num, self.head_num, self.head_dim)
elif self.layout == "page_first_direct":
dims = (
2,
self.page_num,
self.layer_num,
self.page_size,
self.head_num,
self.head_dim,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
self.token_stride_size = self.head_num * self.head_dim * self.dtype.itemsize
......@@ -370,19 +382,31 @@ class MHATokenToKVPoolHost(HostKVCache):
else:
raise ValueError(f"Unsupported layout: {self.layout}")
elif io_backend == "direct":
assert (
self.layout == "layer_first"
), f"Direct IO backend only supports layer_first layout."
transfer_kv_direct(
src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
dst_layers=[
device_pool.k_buffer[layer_id],
device_pool.v_buffer[layer_id],
],
src_indices=host_indices,
dst_indices=device_indices,
page_size=self.page_size,
)
if self.layout == "layer_first":
transfer_kv_direct(
src_layers=[self.k_buffer[layer_id], self.v_buffer[layer_id]],
dst_layers=[
device_pool.k_buffer[layer_id],
device_pool.v_buffer[layer_id],
],
src_indices=host_indices,
dst_indices=device_indices,
page_size=self.page_size,
)
elif self.layout == "page_first_direct":
transfer_kv_per_layer_direct_pf_lf(
src_ptrs=[self.k_buffer, self.v_buffer],
dst_ptrs=[
device_pool.k_buffer[layer_id],
device_pool.v_buffer[layer_id],
],
src_indices=host_indices,
dst_indices=device_indices,
layer_id=layer_id,
page_size=self.page_size,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
else:
raise ValueError(f"Unsupported IO backend: {io_backend}")
......@@ -416,16 +440,24 @@ class MHATokenToKVPoolHost(HostKVCache):
else:
raise ValueError(f"Unsupported layout: {self.layout}")
elif io_backend == "direct":
assert (
self.layout == "layer_first"
), f"Direct IO backend only supports layer_first layout."
transfer_kv_direct(
src_layers=device_pool.k_buffer + device_pool.v_buffer,
dst_layers=self.k_data_refs + self.v_data_refs,
src_indices=device_indices,
dst_indices=host_indices,
page_size=self.page_size,
)
if self.layout == "layer_first":
transfer_kv_direct(
src_layers=device_pool.k_buffer + device_pool.v_buffer,
dst_layers=self.k_data_refs + self.v_data_refs,
src_indices=device_indices,
dst_indices=host_indices,
page_size=self.page_size,
)
elif self.layout == "page_first_direct":
transfer_kv_all_layer_direct_lf_pf(
src_ptrs=device_pool.k_buffer + device_pool.v_buffer,
dst_ptrs=[self.k_buffer, self.v_buffer],
src_indices=device_indices,
dst_indices=host_indices,
page_size=self.page_size,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
else:
raise ValueError(f"Unsupported IO backend: {io_backend}")
......@@ -580,6 +612,14 @@ class MLATokenToKVPoolHost(HostKVCache):
1,
self.kv_lora_rank + self.qk_rope_head_dim,
)
elif self.layout == "page_first_direct":
dims = (
self.page_num,
self.layer_num,
self.page_size,
1,
self.kv_lora_rank + self.qk_rope_head_dim,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
self.token_stride_size = (
......@@ -619,16 +659,25 @@ class MLATokenToKVPoolHost(HostKVCache):
else:
raise ValueError(f"Unsupported layout: {self.layout}")
elif io_backend == "direct":
assert (
self.layout == "layer_first"
), f"Direct IO backend only supports layer_first layout."
transfer_kv_direct(
src_layers=[self.kv_buffer[layer_id]],
dst_layers=[device_pool.kv_buffer[layer_id]],
src_indices=host_indices,
dst_indices=device_indices,
page_size=self.page_size,
)
if self.layout == "layer_first":
transfer_kv_direct(
src_layers=[self.kv_buffer[layer_id]],
dst_layers=[device_pool.kv_buffer[layer_id]],
src_indices=host_indices,
dst_indices=device_indices,
page_size=self.page_size,
)
elif self.layout == "page_first_direct":
transfer_kv_per_layer_direct_pf_lf(
src_ptrs=[self.kv_buffer],
dst_ptrs=[device_pool.kv_buffer[layer_id]],
src_indices=host_indices,
dst_indices=device_indices,
layer_id=layer_id,
page_size=self.page_size,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
def backup_from_device_all_layer(
self, device_pool, host_indices, device_indices, io_backend
......@@ -656,16 +705,24 @@ class MLATokenToKVPoolHost(HostKVCache):
else:
raise ValueError(f"Unsupported layout: {self.layout}")
elif io_backend == "direct":
assert (
self.layout == "layer_first"
), f"Direct IO backend only supports layer_first layout."
transfer_kv_direct(
src_layers=device_pool.kv_buffer,
dst_layers=self.data_refs,
src_indices=device_indices,
dst_indices=host_indices,
page_size=self.page_size,
)
if self.layout == "layer_first":
transfer_kv_direct(
src_layers=device_pool.kv_buffer,
dst_layers=self.data_refs,
src_indices=device_indices,
dst_indices=host_indices,
page_size=self.page_size,
)
elif self.layout == "page_first_direct":
transfer_kv_all_layer_direct_lf_pf(
src_ptrs=device_pool.kv_buffer,
dst_ptrs=[self.kv_buffer],
src_indices=device_indices,
dst_indices=host_indices,
page_size=self.page_size,
)
else:
raise ValueError(f"Unsupported layout: {self.layout}")
else:
raise ValueError(f"Unsupported IO backend: {io_backend}")
......
......@@ -721,6 +721,13 @@ class ServerArgs:
self.hicache_io_backend = "kernel"
self.hicache_mem_layout = "page_first"
if self.hicache_mem_layout == "page_first_direct":
if self.hicache_io_backend != "direct":
self.hicache_io_backend = "direct"
logger.warning(
"Page first direct layout only support direct io backend"
)
# Speculative Decoding
if self.speculative_algorithm == "NEXTN":
# NEXTN shares the same implementation of EAGLE
......@@ -1779,7 +1786,7 @@ class ServerArgs:
parser.add_argument(
"--hicache-mem-layout",
type=str,
choices=["layer_first", "page_first"],
choices=["layer_first", "page_first", "page_first_direct"],
default=ServerArgs.hicache_mem_layout,
help="The layout of host memory pool for hierarchical cache.",
)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment