Commit 86497d99 (unverified), authored by huangtingwei, committed by GitHub
Browse files

fix page first per layer pf2lf kernel (#8915)


Co-authored-by: Zhiqiang Xie <xiezhq@stanford.edu>
parent 5c31b35d
...@@ -358,6 +358,7 @@ class MHATokenToKVPoolHost(HostKVCache): ...@@ -358,6 +358,7 @@ class MHATokenToKVPoolHost(HostKVCache):
dst_v=device_pool.v_buffer[layer_id], dst_v=device_pool.v_buffer[layer_id],
src_indices=host_indices, src_indices=host_indices,
dst_indices=device_indices, dst_indices=device_indices,
layer_id=layer_id,
item_size=self.token_stride_size, item_size=self.token_stride_size,
src_layout_dim=self.layout_dim, src_layout_dim=self.layout_dim,
) )
...@@ -585,6 +586,7 @@ class MLATokenToKVPoolHost(HostKVCache): ...@@ -585,6 +586,7 @@ class MLATokenToKVPoolHost(HostKVCache):
dst=device_pool.kv_buffer[layer_id], dst=device_pool.kv_buffer[layer_id],
src_indices=host_indices, src_indices=host_indices,
dst_indices=device_indices, dst_indices=device_indices,
layer_id=layer_id,
item_size=self.token_stride_size, item_size=self.token_stride_size,
src_layout_dim=self.layout_dim, src_layout_dim=self.layout_dim,
) )
......
...@@ -250,7 +250,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { ...@@ -250,7 +250,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer); m.impl("transfer_kv_per_layer", torch::kCUDA, &transfer_kv_per_layer);
m.def( m.def(
"transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor " "transfer_kv_per_layer_pf_lf(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()"); "dst_indices, int layer_id, int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf); m.impl("transfer_kv_per_layer_pf_lf", torch::kCUDA, &transfer_kv_per_layer_pf_lf);
m.def( m.def(
"transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, " "transfer_kv_all_layer(Tensor src_k_layers, Tensor dst_k_layers, Tensor src_v_layers, Tensor dst_v_layers, "
...@@ -267,8 +267,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { ...@@ -267,8 +267,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"block_quota, int num_warps_per_block) -> ()"); "block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla); m.impl("transfer_kv_per_layer_mla", torch::kCUDA, &transfer_kv_per_layer_mla);
m.def( m.def(
"transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int item_size, " "transfer_kv_per_layer_mla_pf_lf(Tensor src, Tensor dst, Tensor src_indices, Tensor dst_indices, int layer_id, "
"int src_layout_dim, int block_quota, int num_warps_per_block) -> ()"); "int item_size, int src_layout_dim, int block_quota, int num_warps_per_block) -> ()");
m.impl("transfer_kv_per_layer_mla_pf_lf", torch::kCUDA, &transfer_kv_per_layer_mla_pf_lf); m.impl("transfer_kv_per_layer_mla_pf_lf", torch::kCUDA, &transfer_kv_per_layer_mla_pf_lf);
m.def( m.def(
"transfer_kv_all_layer_mla(Tensor src_layers, Tensor dst_layers, Tensor src_indices, Tensor dst_indices, int " "transfer_kv_all_layer_mla(Tensor src_layers, Tensor dst_layers, Tensor src_indices, Tensor dst_indices, int "
......
...@@ -210,6 +210,7 @@ void transfer_kv_per_layer_pf_lf( ...@@ -210,6 +210,7 @@ void transfer_kv_per_layer_pf_lf(
at::Tensor dst_v, at::Tensor dst_v,
const at::Tensor src_indices, const at::Tensor src_indices,
const at::Tensor dst_indices, const at::Tensor dst_indices,
int64_t layer_id,
int64_t item_size, int64_t item_size,
int64_t src_layout_dim, int64_t src_layout_dim,
int64_t block_quota, int64_t block_quota,
...@@ -222,7 +223,7 @@ void transfer_kv_per_layer_pf_lf( ...@@ -222,7 +223,7 @@ void transfer_kv_per_layer_pf_lf(
dst_v, dst_v,
src_indices, src_indices,
dst_indices, dst_indices,
0, layer_id,
1, 1,
item_size, item_size,
src_layout_dim, src_layout_dim,
...@@ -336,6 +337,7 @@ void transfer_kv_per_layer_mla_pf_lf( ...@@ -336,6 +337,7 @@ void transfer_kv_per_layer_mla_pf_lf(
at::Tensor dst, at::Tensor dst,
const at::Tensor src_indices, const at::Tensor src_indices,
const at::Tensor dst_indices, const at::Tensor dst_indices,
int64_t layer_id,
int64_t item_size, int64_t item_size,
int64_t src_layout_dim, int64_t src_layout_dim,
int64_t block_quota, int64_t block_quota,
...@@ -348,7 +350,7 @@ void transfer_kv_per_layer_mla_pf_lf( ...@@ -348,7 +350,7 @@ void transfer_kv_per_layer_mla_pf_lf(
empty, empty,
src_indices, src_indices,
dst_indices, dst_indices,
0, layer_id,
1, 1,
item_size, item_size,
src_layout_dim, src_layout_dim,
......
...@@ -419,6 +419,7 @@ void transfer_kv_per_layer_pf_lf( ...@@ -419,6 +419,7 @@ void transfer_kv_per_layer_pf_lf(
at::Tensor dst_v, at::Tensor dst_v,
const at::Tensor src_indices, const at::Tensor src_indices,
const at::Tensor dst_indices, const at::Tensor dst_indices,
int64_t layer_id,
int64_t item_size, int64_t item_size,
int64_t src_layout_dim, int64_t src_layout_dim,
int64_t block_quota, int64_t block_quota,
...@@ -463,6 +464,7 @@ void transfer_kv_per_layer_mla_pf_lf( ...@@ -463,6 +464,7 @@ void transfer_kv_per_layer_mla_pf_lf(
at::Tensor dst, at::Tensor dst,
const at::Tensor src_indices, const at::Tensor src_indices,
const at::Tensor dst_indices, const at::Tensor dst_indices,
int64_t layer_id,
int64_t item_size, int64_t item_size,
int64_t src_layout_dim, int64_t src_layout_dim,
int64_t block_quota, int64_t block_quota,
......
...@@ -34,6 +34,7 @@ def transfer_kv_per_layer_pf_lf( ...@@ -34,6 +34,7 @@ def transfer_kv_per_layer_pf_lf(
dst_v: torch.Tensor, dst_v: torch.Tensor,
src_indices: torch.Tensor, src_indices: torch.Tensor,
dst_indices: torch.Tensor, dst_indices: torch.Tensor,
layer_id: int,
item_size: int, item_size: int,
src_layout_dim: int, src_layout_dim: int,
block_quota: int = 2, block_quota: int = 2,
...@@ -46,6 +47,7 @@ def transfer_kv_per_layer_pf_lf( ...@@ -46,6 +47,7 @@ def transfer_kv_per_layer_pf_lf(
dst_v, dst_v,
src_indices, src_indices,
dst_indices, dst_indices,
layer_id,
item_size, item_size,
src_layout_dim, src_layout_dim,
block_quota, block_quota,
...@@ -144,6 +146,7 @@ def transfer_kv_per_layer_mla_pf_lf( ...@@ -144,6 +146,7 @@ def transfer_kv_per_layer_mla_pf_lf(
dst: torch.Tensor, dst: torch.Tensor,
src_indices: torch.Tensor, src_indices: torch.Tensor,
dst_indices: torch.Tensor, dst_indices: torch.Tensor,
layer_id: int,
item_size: int, item_size: int,
src_layout_dim: int, src_layout_dim: int,
block_quota: int = 2, block_quota: int = 2,
...@@ -154,6 +157,7 @@ def transfer_kv_per_layer_mla_pf_lf( ...@@ -154,6 +157,7 @@ def transfer_kv_per_layer_mla_pf_lf(
dst, dst,
src_indices, src_indices,
dst_indices, dst_indices,
layer_id,
item_size, item_size,
src_layout_dim, src_layout_dim,
block_quota, block_quota,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment