Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
jerrrrry
infinilm
Commits
db19cc0b
Commit
db19cc0b
authored
Jan 09, 2026
by
PanZezhong
Browse files
issue/168 use n_blocks to init paged kv cache config, support fixed paged caching api
parent
831e8a67
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
28 additions
and
40 deletions
+28
-40
csrc/cache/kv_cache.cpp
csrc/cache/kv_cache.cpp
+11
-18
csrc/cache/kv_cache.hpp
csrc/cache/kv_cache.hpp
+3
-3
csrc/pybind11/cache/cache.hpp
csrc/pybind11/cache/cache.hpp
+3
-3
examples/jiuge.py
examples/jiuge.py
+9
-14
python/infinilm/cache/cache.py
python/infinilm/cache/cache.py
+2
-2
No files found.
csrc/cache/kv_cache.cpp
View file @
db19cc0b
...
...
@@ -111,9 +111,9 @@ StaticKVCache::update(size_t layer_idx,
// PagedKVCacheConfig
// ==========================
PagedKVCacheConfig
::
PagedKVCacheConfig
(
size_t
max_kv_memory_byte
s
,
size_t
num_block
s
,
size_t
block_size
)
:
max_kv_memory_bytes_
(
max_kv_memory_byte
s
),
:
num_blocks_
(
num_block
s
),
block_size_
(
block_size
)
{
}
...
...
@@ -123,8 +123,8 @@ PagedKVCacheConfig::unique_copy() const {
}
size_t
PagedKVCacheConfig
::
max_kv_memory_byte
s
()
const
{
return
max_kv_memory_byte
s_
;
PagedKVCacheConfig
::
num_block
s
()
const
{
return
num_block
s_
;
}
size_t
...
...
@@ -151,16 +151,8 @@ PagedKVCache::PagedKVCache(
num_rank_v_heads_
(
num_v_heads
/
rank_info
.
tp_size
),
rank_num_layers_
(
num_layers
),
dtype_
(
dtype
),
num_blocks_per_layer_
(
config
.
num_blocks
()),
block_size_
(
config
.
block_size
())
{
num_blocks_per_layer_
=
config
.
max_kv_memory_bytes
()
/
(
k_dim
*
num_rank_k_heads_
+
v_dim
*
num_rank_v_heads_
)
/
block_size_
/
rank_num_layers_
/
infinicore
::
dsize
(
dtype_
);
if
(
num_blocks_per_layer_
==
0
)
{
throw
std
::
runtime_error
(
"Not enough memory for KV cache"
);
}
// [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
k_caches_
=
infinicore
::
Tensor
::
empty
(
{
rank_num_layers_
,
...
...
@@ -190,11 +182,12 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
auto
&&
[
k_cache_layer
,
v_cache_layer
]
=
get_paged_kv
(
layer_idx
);
infinicore
::
op
::
paged_caching_
(
k
,
v
,
k_cache_layer
,
v_cache_layer
,
slot_mapping
);
infinicore
::
op
::
paged_caching_
(
k_cache_layer
,
v_cache_layer
,
k
,
v
,
slot_mapping
);
return
{
k_cache_layer
,
v_cache_layer
};
}
...
...
csrc/cache/kv_cache.hpp
View file @
db19cc0b
...
...
@@ -85,15 +85,15 @@ private:
class
PagedKVCacheConfig
final
:
public
CacheConfig
{
public:
PagedKVCacheConfig
(
size_t
max_kv_memory_byte
s
,
size_t
num_block
s
,
size_t
block_size
=
16
);
std
::
unique_ptr
<
CacheConfig
>
unique_copy
()
const
override
;
size_t
max_kv_memory_byte
s
()
const
;
size_t
num_block
s
()
const
;
size_t
block_size
()
const
;
private:
size_t
max_kv_memory_byte
s_
;
size_t
num_block
s_
;
size_t
block_size_
;
};
...
...
csrc/pybind11/cache/cache.hpp
View file @
db19cc0b
...
...
@@ -36,11 +36,11 @@ inline void bind_cache(py::module &m) {
std
::
shared_ptr
<
infinilm
::
cache
::
PagedKVCacheConfig
>>
(
m
,
"PagedKVCacheConfig"
)
.
def
(
py
::
init
<
size_t
,
size_t
>
(),
py
::
arg
(
"
max_kv_memory_byte
s"
),
py
::
arg
(
"
num_block
s"
),
py
::
arg
(
"block_size"
)
=
16
)
.
def
(
"
max_kv_memory_byte
s"
,
&
infinilm
::
cache
::
PagedKVCacheConfig
::
max_kv_memory_byte
s
)
"
num_block
s"
,
&
infinilm
::
cache
::
PagedKVCacheConfig
::
num_block
s
)
.
def
(
"block_size"
,
&
infinilm
::
cache
::
PagedKVCacheConfig
::
block_size
)
...
...
examples/jiuge.py
View file @
db19cc0b
...
...
@@ -89,13 +89,6 @@ def get_args():
help
=
"use paged cache"
,
)
parser
.
add_argument
(
"--max-kvcache-size"
,
type
=
int
,
default
=
8
*
1024
*
1024
*
1024
,
help
=
"max size (in bytes) allocated to paged kv cache"
,
)
return
parser
.
parse_args
()
...
...
@@ -109,7 +102,7 @@ def test(
):
model_path
=
os
.
path
.
expanduser
(
model_path
)
# ---------------------------------------------------------------------------- #
#
创建模型,
#
Create Model
# ---------------------------------------------------------------------------- #
model
=
InferEngine
(
model_path
,
...
...
@@ -118,12 +111,12 @@ def test(
)
# ---------------------------------------------------------------------------- #
#
加载权重
#
Load Weights
# ---------------------------------------------------------------------------- #
load_model_state_dict_by_file
(
model
,
model_path
,
dtype
=
model
.
config
.
dtype
)
# ---------------------------------------------------------------------------- #
#
创建
tokenizer
#
create
tokenizer
# ---------------------------------------------------------------------------- #
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_path
,
trust_remote_code
=
True
)
...
...
@@ -146,7 +139,7 @@ def test(
)
# ---------------------------------------------------------------------------- #
# token
编码
# token
ize
# ---------------------------------------------------------------------------- #
# prompt = "山东最高的山是?"
if
isinstance
(
prompts
,
str
):
...
...
@@ -165,11 +158,13 @@ def test(
]
# List: [[1, 1128, 526, 366, 29892]]
# ---------------------------------------------------------------------------- #
#
创建
KVCache
#
Create
KVCache
# ---------------------------------------------------------------------------- #
if
enable_paged_attn
:
batch_size
=
1
if
prompts
is
str
else
len
(
prompts
)
max_total_tokens
=
max_new_tokens
+
len
(
input_ids_list
[
0
])
cache_config
=
PagedKVCacheConfig
(
max_kv_memory_bytes
=
args
.
max_kvca
ch
e
_size
,
block_size
=
16
num_blocks
=
(
max_total_tokens
//
16
+
1
)
*
bat
ch_size
,
block_size
=
16
)
else
:
batch_size
=
1
if
prompts
is
str
else
len
(
prompts
)
...
...
@@ -181,7 +176,7 @@ def test(
model
.
reset_cache
(
cache_config
)
# ---------------------------------------------------------------------------- #
#
自回归生成
#
Generate
# ---------------------------------------------------------------------------- #
print
(
input_contents
[
0
],
end
=
""
,
flush
=
True
)
input_ids_infini
=
infinicore
.
from_list
(
input_ids_list
)
...
...
python/infinilm/cache/cache.py
View file @
db19cc0b
...
...
@@ -16,11 +16,11 @@ class StaticKVCacheConfig(CacheConfig, _infinilm.StaticKVCacheConfig):
class
PagedKVCacheConfig
(
CacheConfig
,
_infinilm
.
PagedKVCacheConfig
):
def
__init__
(
self
,
max_kv_memory_byte
s
:
int
,
num_block
s
:
int
,
block_size
:
int
=
16
,
):
_infinilm
.
PagedKVCacheConfig
.
__init__
(
self
,
max_kv_memory_byte
s
,
num_block
s
,
block_size
,
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment