Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
bf74389d
"docs/git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "f380bc5effa34d2925c9faa5120da5dd8fd79b7d"
Commit
bf74389d
authored
Jan 04, 2026
by
PanZezhong
Browse files
issue/168 get contiguous paged kv cache
parent
f246c4f1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
102 additions
and
4 deletions
+102
-4
csrc/cache/kv_cache.cpp
csrc/cache/kv_cache.cpp
+66
-2
csrc/cache/kv_cache.hpp
csrc/cache/kv_cache.hpp
+36
-2
No files found.
csrc/cache/kv_cache.cpp
View file @
bf74389d
...
@@ -188,8 +188,7 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
...
@@ -188,8 +188,7 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
const
infinicore
::
Tensor
&
v
,
const
infinicore
::
Tensor
&
v
,
const
infinicore
::
Tensor
&
slot_mapping
)
{
const
infinicore
::
Tensor
&
slot_mapping
)
{
auto
k_cache_layer
=
k_caches_
->
narrow
({{
0
,
layer_idx
,
1
}})
->
squeeze
(
0
);
auto
&&
[
k_cache_layer
,
v_cache_layer
]
=
get_paged_kv
(
layer_idx
);
auto
v_cache_layer
=
v_caches_
->
narrow
({{
0
,
layer_idx
,
1
}})
->
squeeze
(
0
);
infinicore
::
op
::
paged_caching_
(
k
,
infinicore
::
op
::
paged_caching_
(
k
,
v
,
v
,
...
@@ -198,4 +197,69 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
...
@@ -198,4 +197,69 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
slot_mapping
);
slot_mapping
);
return
{
k_cache_layer
,
v_cache_layer
};
return
{
k_cache_layer
,
v_cache_layer
};
}
}
/// Return the per-layer K and V cache pools as a (k, v) tuple.
/// The caches are stored with a leading layer axis; this selects the slice for
/// `layer_idx` and drops that axis, yielding the layer's paged pools
/// (shape per the header docs: [num_blocks, num_rank_heads, block_size, dim]).
std::tuple<infinicore::Tensor, infinicore::Tensor>
PagedKVCache::get_paged_kv(size_t layer_idx) {
    // narrow({{0, layer_idx, 1}}) keeps a length-1 layer axis; squeeze(0) removes it.
    auto k_layer = k_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
    auto v_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
    return std::make_tuple(k_layer, v_layer);
}
/// Gather one request's paged KV cache into freshly allocated contiguous tensors.
///
/// @param layer_idx      Which paged-attention layer to read from.
/// @param block_tables   [num_requests, max_blocks_per_request], I64 physical block ids.
/// @param cache_lens     [num_requests], I64 number of already-cached tokens per request.
/// @param input_offsets  [num_requests + 1], I64 prefix offsets of the current batch's tokens.
/// @param request_id     Which request of the continuous batch to gather.
/// @return (full_k, full_v): [num_rank_k_heads, total_len, k_dim] and
///         [num_rank_v_heads, total_len, v_dim] on the cache's device.
///
/// NOTE(review): the Tensor parameters are passed by const value; passing by
/// `const infinicore::Tensor &` (as update() does) would avoid handle copies, but
/// requires the matching change in the header declaration.
std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::get_contiguous_kv(
    size_t layer_idx,
    const infinicore::Tensor block_tables,
    const infinicore::Tensor cache_lens,
    const infinicore::Tensor input_offsets,
    size_t request_id) {
    ASSERT_EQ(block_tables->dtype(), infinicore::DataType::I64);
    ASSERT_EQ(cache_lens->dtype(), infinicore::DataType::I64);
    ASSERT_EQ(input_offsets->dtype(), infinicore::DataType::I64);
    // NOTE(review): request_id is not range-checked against block_tables->size(0)
    // (the original computed that size into an unused local); confirm callers
    // guarantee request_id < num_requests.

    // Bring the bookkeeping tensors to host so they can be indexed directly.
    auto block_tables_cpu = block_tables->to(infinicore::Device::cpu());
    auto cache_lens_cpu = cache_lens->to(infinicore::Device::cpu());
    auto input_offsets_cpu = input_offsets->to(infinicore::Device::cpu());
    // Ensure the device-to-host copies above have completed before reading.
    infinicore::context::syncDevice();

    // Layer pools, e.g. v: [num_blocks, num_rank_v_heads, block_size, v_dim].
    auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);

    const auto req = request_id;
    const auto *cache_lens_ptr = reinterpret_cast<const int64_t *>(cache_lens_cpu->data());
    const auto *input_offsets_ptr = reinterpret_cast<const int64_t *>(input_offsets_cpu->data());
    // Total tokens for this request = previously cached + new tokens in the current batch.
    const int64_t total_len = cache_lens_ptr[req] + (input_offsets_ptr[req + 1] - input_offsets_ptr[req]);

    auto full_k = infinicore::Tensor::empty(
        {num_rank_k_heads_, static_cast<size_t>(total_len), k_dim_},
        k_cache_layer->dtype(), k_cache_layer->device());
    auto full_v = infinicore::Tensor::empty(
        {num_rank_v_heads_, static_cast<size_t>(total_len), v_dim_},
        v_cache_layer->dtype(), v_cache_layer->device());

    const size_t nblocks = total_len / block_size_; // fully occupied blocks
    const size_t r = total_len % block_size_;       // tokens in the trailing partial block

    // Physical block id of logical block `b` for this request (table is on CPU).
    auto block_id = [&](size_t b) {
        return static_cast<size_t>(*reinterpret_cast<const int64_t *>(
            block_tables_cpu->narrow({{0, req, 1}, {1, b, 1}})->data()));
    };
    // Copy the first `len` tokens of logical block `b` into the contiguous
    // outputs at token offset `dst` (deduplicates the K/V copy logic).
    auto copy_block = [&](size_t b, size_t dst, size_t len) {
        const size_t bid = block_id(b);
        full_k->narrow({{1, dst, len}})
            ->copy_from(k_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, len}}));
        full_v->narrow({{1, dst, len}})
            ->copy_from(v_cache_layer->narrow({{0, bid, 1}})->squeeze(0)->narrow({{1, 0, len}}));
    };

    for (size_t b = 0; b < nblocks; ++b) {
        copy_block(b, b * block_size_, block_size_);
    }
    if (r > 0) {
        copy_block(nblocks, nblocks * block_size_, r);
    }
    return {full_k, full_v};
}
}
// namespace infinilm::cache
}
// namespace infinilm::cache
csrc/cache/kv_cache.hpp
View file @
bf74389d
...
@@ -113,7 +113,7 @@ public:
...
@@ -113,7 +113,7 @@ public:
/**
/**
* @brief Update Paged KV cache at a given layer given slot info for each token.
* @brief Update Paged KV cache at a given layer given slot info for each token.
*
*
* @param layer_idx Which
transformer
layer
* @param layer_idx Which
paged attention
layer
* @param k [num_rank_k_heads, seq_len, k_dim]
* @param k [num_rank_k_heads, seq_len, k_dim]
* @param v [num_rank_v_heads, seq_len, v_dim]
* @param v [num_rank_v_heads, seq_len, v_dim]
* @param slot_mapping [seq_len]
* @param slot_mapping [seq_len]
...
@@ -128,7 +128,41 @@ public:
...
@@ -128,7 +128,41 @@ public:
const
infinicore
::
Tensor
&
v
,
const
infinicore
::
Tensor
&
v
,
const
infinicore
::
Tensor
&
slot_mapping
);
const
infinicore
::
Tensor
&
slot_mapping
);
~
PagedKVCache
()
override
=
default
;
/**
 * @brief Get Paged KV cache at a given layer.
 *
 * @param layer_idx Which paged attention layer
 *
 * @return (full_k, full_v)
 *         full_k: [num_blocks, num_rank_k_heads, block_size, k_dim]
 *         full_v: [num_blocks, num_rank_v_heads, block_size, v_dim]
 */
std::tuple<infinicore::Tensor, infinicore::Tensor> get_paged_kv(size_t layer_idx);

/**
 * @brief Get contiguous KV cache at a given layer, given the request info
 *        among a continuous request batch.
 *
 * @param layer_idx Which paged attention layer
 * @param block_tables [num_requests, max_blocks_per_request]
 * @param cache_lens [num_requests]
 * @param input_offsets [num_requests + 1]
 * @param request_id Which request among a continuous batch of requests
 *
 * @return (full_k, full_v)
 *         full_k: [num_rank_k_heads, total_len, k_dim]
 *         full_v: [num_rank_v_heads, total_len, v_dim]
 */
// NOTE(review): the Tensor parameters are taken by const value; `const
// infinicore::Tensor &` (as update() uses) would avoid handle copies — a
// change here must be mirrored in the definition in kv_cache.cpp.
std::tuple<infinicore::Tensor, infinicore::Tensor> get_contiguous_kv(
    size_t layer_idx,
    const infinicore::Tensor block_tables,
    const infinicore::Tensor cache_lens,
    const infinicore::Tensor input_offsets,
    size_t request_id = 0);

// Base class is polymorphic; defaulted virtual destructor is sufficient (Rule of Zero).
~PagedKVCache() override = default;
private:
private:
infinicore
::
Size
k_dim_
;
infinicore
::
Size
k_dim_
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment