Commit db19cc0b authored by PanZezhong's avatar PanZezhong
Browse files

issue/168 use n_blocks to init paged kv cache config, support fixed paged caching api

parent 831e8a67
...@@ -111,9 +111,9 @@ StaticKVCache::update(size_t layer_idx, ...@@ -111,9 +111,9 @@ StaticKVCache::update(size_t layer_idx,
// PagedKVCacheConfig // PagedKVCacheConfig
// ========================== // ==========================
PagedKVCacheConfig::PagedKVCacheConfig( PagedKVCacheConfig::PagedKVCacheConfig(
size_t max_kv_memory_bytes, size_t num_blocks,
size_t block_size) size_t block_size)
: max_kv_memory_bytes_(max_kv_memory_bytes), : num_blocks_(num_blocks),
block_size_(block_size) { block_size_(block_size) {
} }
...@@ -123,8 +123,8 @@ PagedKVCacheConfig::unique_copy() const { ...@@ -123,8 +123,8 @@ PagedKVCacheConfig::unique_copy() const {
} }
size_t size_t
PagedKVCacheConfig::max_kv_memory_bytes() const { PagedKVCacheConfig::num_blocks() const {
return max_kv_memory_bytes_; return num_blocks_;
} }
size_t size_t
...@@ -151,16 +151,8 @@ PagedKVCache::PagedKVCache( ...@@ -151,16 +151,8 @@ PagedKVCache::PagedKVCache(
num_rank_v_heads_(num_v_heads / rank_info.tp_size), num_rank_v_heads_(num_v_heads / rank_info.tp_size),
rank_num_layers_(num_layers), rank_num_layers_(num_layers),
dtype_(dtype), dtype_(dtype),
num_blocks_per_layer_(config.num_blocks()),
block_size_(config.block_size()) { block_size_(config.block_size()) {
num_blocks_per_layer_ = config.max_kv_memory_bytes()
/ (k_dim * num_rank_k_heads_ + v_dim * num_rank_v_heads_)
/ block_size_
/ rank_num_layers_
/ infinicore::dsize(dtype_);
if (num_blocks_per_layer_ == 0) {
throw std::runtime_error("Not enough memory for KV cache");
}
// [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim] // [num_layers, num_blocks, num_rank_k_heads, block_size, k_dim]
k_caches_ = infinicore::Tensor::empty( k_caches_ = infinicore::Tensor::empty(
{rank_num_layers_, {rank_num_layers_,
...@@ -190,11 +182,12 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update( ...@@ -190,11 +182,12 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx); auto &&[k_cache_layer, v_cache_layer] = get_paged_kv(layer_idx);
infinicore::op::paged_caching_(k, infinicore::op::paged_caching_(
v, k_cache_layer,
k_cache_layer, v_cache_layer,
v_cache_layer, k,
slot_mapping); v,
slot_mapping);
return {k_cache_layer, v_cache_layer}; return {k_cache_layer, v_cache_layer};
} }
......
...@@ -85,15 +85,15 @@ private: ...@@ -85,15 +85,15 @@ private:
class PagedKVCacheConfig final : public CacheConfig { class PagedKVCacheConfig final : public CacheConfig {
public: public:
PagedKVCacheConfig( PagedKVCacheConfig(
size_t max_kv_memory_bytes, size_t num_blocks,
size_t block_size = 16); size_t block_size = 16);
std::unique_ptr<CacheConfig> unique_copy() const override; std::unique_ptr<CacheConfig> unique_copy() const override;
size_t max_kv_memory_bytes() const; size_t num_blocks() const;
size_t block_size() const; size_t block_size() const;
private: private:
size_t max_kv_memory_bytes_; size_t num_blocks_;
size_t block_size_; size_t block_size_;
}; };
......
...@@ -36,11 +36,11 @@ inline void bind_cache(py::module &m) { ...@@ -36,11 +36,11 @@ inline void bind_cache(py::module &m) {
std::shared_ptr<infinilm::cache::PagedKVCacheConfig>>(m, "PagedKVCacheConfig") std::shared_ptr<infinilm::cache::PagedKVCacheConfig>>(m, "PagedKVCacheConfig")
.def( .def(
py::init<size_t, size_t>(), py::init<size_t, size_t>(),
py::arg("max_kv_memory_bytes"), py::arg("num_blocks"),
py::arg("block_size") = 16) py::arg("block_size") = 16)
.def( .def(
"max_kv_memory_bytes", "num_blocks",
&infinilm::cache::PagedKVCacheConfig::max_kv_memory_bytes) &infinilm::cache::PagedKVCacheConfig::num_blocks)
.def( .def(
"block_size", "block_size",
&infinilm::cache::PagedKVCacheConfig::block_size) &infinilm::cache::PagedKVCacheConfig::block_size)
......
...@@ -89,13 +89,6 @@ def get_args(): ...@@ -89,13 +89,6 @@ def get_args():
help="use paged cache", help="use paged cache",
) )
parser.add_argument(
"--max-kvcache-size",
type=int,
default=8 * 1024 * 1024 * 1024,
help="max size (in bytes) allocated to paged kv cache",
)
return parser.parse_args() return parser.parse_args()
...@@ -109,7 +102,7 @@ def test( ...@@ -109,7 +102,7 @@ def test(
): ):
model_path = os.path.expanduser(model_path) model_path = os.path.expanduser(model_path)
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# 创建模型, # Create Model
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
model = InferEngine( model = InferEngine(
model_path, model_path,
...@@ -118,12 +111,12 @@ def test( ...@@ -118,12 +111,12 @@ def test(
) )
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# 加载权重 # Load Weights
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
load_model_state_dict_by_file(model, model_path, dtype=model.config.dtype) load_model_state_dict_by_file(model, model_path, dtype=model.config.dtype)
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# 创建 tokenizer # create tokenizer
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
...@@ -146,7 +139,7 @@ def test( ...@@ -146,7 +139,7 @@ def test(
) )
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# token编码 # tokenize
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# prompt = "山东最高的山是?" # prompt = "山东最高的山是?"
if isinstance(prompts, str): if isinstance(prompts, str):
...@@ -165,11 +158,13 @@ def test( ...@@ -165,11 +158,13 @@ def test(
] # List: [[1, 1128, 526, 366, 29892]] ] # List: [[1, 1128, 526, 366, 29892]]
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# 创建KVCache # Create KVCache
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
if enable_paged_attn: if enable_paged_attn:
batch_size = 1 if prompts is str else len(prompts)
max_total_tokens = max_new_tokens + len(input_ids_list[0])
cache_config = PagedKVCacheConfig( cache_config = PagedKVCacheConfig(
max_kv_memory_bytes=args.max_kvcache_size, block_size=16 num_blocks=(max_total_tokens // 16 + 1) * batch_size, block_size=16
) )
else: else:
batch_size = 1 if prompts is str else len(prompts) batch_size = 1 if prompts is str else len(prompts)
...@@ -181,7 +176,7 @@ def test( ...@@ -181,7 +176,7 @@ def test(
model.reset_cache(cache_config) model.reset_cache(cache_config)
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
# 自回归生成 # Generate
# ---------------------------------------------------------------------------- # # ---------------------------------------------------------------------------- #
print(input_contents[0], end="", flush=True) print(input_contents[0], end="", flush=True)
input_ids_infini = infinicore.from_list(input_ids_list) input_ids_infini = infinicore.from_list(input_ids_list)
......
...@@ -16,11 +16,11 @@ class StaticKVCacheConfig(CacheConfig, _infinilm.StaticKVCacheConfig): ...@@ -16,11 +16,11 @@ class StaticKVCacheConfig(CacheConfig, _infinilm.StaticKVCacheConfig):
class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig): class PagedKVCacheConfig(CacheConfig, _infinilm.PagedKVCacheConfig):
def __init__( def __init__(
self, self,
max_kv_memory_bytes: int, num_blocks: int,
block_size: int = 16, block_size: int = 16,
): ):
_infinilm.PagedKVCacheConfig.__init__( _infinilm.PagedKVCacheConfig.__init__(
self, self,
max_kv_memory_bytes, num_blocks,
block_size, block_size,
) )
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment