Unverified Commit 5c31b35d authored by cctry's avatar cctry Committed by GitHub
Browse files

[hicache] Optimization for DMA copy (#8245)

parent ef48d554
......@@ -433,7 +433,9 @@ class HiCacheController:
if self.io_backend == "kernel":
return host_indices.to(self.mem_pool_device.device), device_indices
elif self.io_backend == "direct":
return host_indices, device_indices.cpu()
device_indices = device_indices.cpu()
host_indices, idx = host_indices.sort()
return host_indices, device_indices.index_select(0, idx)
else:
raise ValueError(f"Unsupported io backend")
......
......@@ -451,15 +451,33 @@ void transfer_kv_direct(
auto src_indices_cpu = src_indices.cpu();
auto dst_indices_cpu = dst_indices.cpu();
const int64_t num_pages = src_indices_cpu.size(0) / page_size;
const auto num_indices = src_indices_cpu.numel();
const int64_t num_layers = src_layers.size();
int64_t* src_indices_ptr = src_indices_cpu.data_ptr<int64_t>();
int64_t* dst_indices_ptr = dst_indices_cpu.data_ptr<int64_t>();
for (int64_t i = 0; i < num_pages; ++i) {
auto src_index = src_indices_cpu[i * page_size].item<int64_t>();
auto dst_index = dst_indices_cpu[i * page_size].item<int64_t>();
int64_t start_index = 0;
int64_t end_index = 0;
for (int64_t i = 0; i < num_indices; ++i) {
if (i < num_indices - 1) {
auto src_diff = src_indices_ptr[i + 1] - src_indices_ptr[i];
auto dst_diff = dst_indices_ptr[i + 1] - dst_indices_ptr[i];
if (src_diff == 1 && dst_diff == 1) {
continue;
}
end_index = i + 1;
} else { // last batch
end_index = num_indices;
}
auto src_index = src_indices_ptr[start_index];
auto dst_index = dst_indices_ptr[start_index];
auto num_tokens = end_index - start_index;
for (int64_t j = 0; j < num_layers; ++j) {
transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, page_size);
transfer_page_direct(src_layers[j], dst_layers[j], src_index, dst_index, num_tokens);
}
start_index = end_index;
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment