Unverified commit d0239867, authored by PanZezhong1725 and committed by GitHub

Merge pull request #156 from InfiniTensor/issue/125_pzz

issue/125: Unify the Cache interface

parents 13a4154a ff00b5c8
@@ -7,6 +7,9 @@
 namespace infinilm {
 class InfinilmModelFactory {
 public:
-    static std::shared_ptr<InfinilmModel> createModel(const InfinilmModel::Config &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), std::shared_ptr<cache::DynamicCache> cache_ptr = nullptr);
+    static std::shared_ptr<InfinilmModel> createModel(
+        const InfinilmModel::Config &config,
+        engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
+        const cache::CacheConfig *cache = nullptr);
 };
 } // namespace infinilm
 #include <pybind11/pybind11.h>
+#include "cache/cache.hpp"
+#include "engine/engine.hpp"
 #include "models/llama.hpp"
-#include "engine.hpp"
 namespace py = pybind11;
 PYBIND11_MODULE(_infinilm, m) {
     m.doc() = "InfiniLM Llama model Python bindings";
-    infinilm::cache::bind_cache_config(m);
+    infinilm::cache::bind_cache(m);
     infinilm::models::llama::bind_llama(m);
     infinilm::engine::distributed::bind_dist_config(m);
     infinilm::engine::bind_infer_engine(m);
...
#include "../../cache/cache.hpp"
#include "infinicore/tensor.hpp"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
namespace infinilm::cache {
inline void bind_cache(py::module &m) {
py::class_<infinilm::cache::CacheConfig,
std::shared_ptr<infinilm::cache::CacheConfig>>(m, "CacheConfig")
.def("__repr__", [](const infinilm::cache::CacheConfig &) {
return "<CacheConfig (abstract)>";
});
py::class_<infinilm::cache::StaticKVCacheConfig,
infinilm::cache::CacheConfig,
std::shared_ptr<infinilm::cache::StaticKVCacheConfig>>(m, "StaticKVCacheConfig")
.def(
py::init<infinicore::Size, infinicore::Size>(),
py::arg("max_batch_size") = 1,
py::arg("max_cache_len") = std::numeric_limits<infinicore::Size>::max())
.def(
"max_batch_size",
&infinilm::cache::StaticKVCacheConfig::max_batch_size)
.def(
"max_cache_len",
&infinilm::cache::StaticKVCacheConfig::max_cache_len)
.def("__repr__", [](const infinilm::cache::StaticKVCacheConfig &) {
return "<StaticKVCacheConfig>";
});
}
} // namespace infinilm::cache
\ No newline at end of file
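A minimal sketch of how these new bindings are used from Python (assumes the _infinilm extension module is built and importable; the values are illustrative):

    from infinilm.lib import _infinilm

    # Both arguments are optional; max_cache_len defaults to the maximum Size value.
    cfg = _infinilm.StaticKVCacheConfig(max_batch_size=2, max_cache_len=4096)
    # max_batch_size/max_cache_len are bound as methods, not properties.
    print(cfg.max_batch_size(), cfg.max_cache_len())
    print(repr(cfg))  # <StaticKVCacheConfig>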
#include "../cache/cache_config.hpp" #include "../../engine/infer_engine.hpp"
#include "../engine/infer_engine.hpp"
#include "infinicore/tensor.hpp" #include "infinicore/tensor.hpp"
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
namespace py = pybind11; namespace py = pybind11;
-namespace infinilm::cache {
-inline void bind_cache_config(py::module &m) {
-    // First bind the CacheType enum
-    py::enum_<CacheType>(m, "CacheType")
-        .value("DYNAMIC", CacheType::DYNAMIC)
-        .value("PAGED", CacheType::PAGED)
-        .export_values();
-    // Then bind the CacheResetMode enum
-    py::enum_<CacheResetMode>(m, "CacheResetMode")
-        .value("PRESERVE", CacheResetMode::PRESERVE)
-        .value("RECREATE", CacheResetMode::RECREATE)
-        .export_values();
-    // Finally bind the CacheConfig struct
-    py::class_<CacheConfig>(m, "CacheConfig")
-        .def(py::init<>(), "Default constructor")
-        .def(py::init<CacheType, size_t, size_t>(),
-             py::arg("type") = CacheType::DYNAMIC,
-             py::arg("num_layers") = 32,
-             py::arg("max_kv_cache_length") = 4096,
-             "Constructor with parameters")
-        .def_readwrite("type", &CacheConfig::type, "Cache type")
-        .def_readwrite("num_layers", &CacheConfig::num_layers, "Number of layers")
-        .def_readwrite("max_kv_cache_length", &CacheConfig::max_kv_cache_length,
-                       "Maximum KV cache length")
-        .def_readwrite("initial_capacity", &CacheConfig::initial_capacity,
-                       "Initial cache capacity in tokens")
-        .def_readwrite("initial_batch_size", &CacheConfig::initial_batch_size,
-                       "Initial batch size for cache allocation")
-        .def_readwrite("growth_factor", &CacheConfig::growth_factor,
-                       "Cache growth factor when resizing (e.g., 2.0 for doubling)")
-        .def_readwrite("allow_expand", &CacheConfig::allow_expand,
-                       "Whether to allow cache expansion")
-        .def_readwrite("reset_mode", &CacheConfig::reset_mode,
-                       "Cache reset mode")
-        .def("__eq__", &CacheConfig::operator==, py::is_operator(),
-             "Check if two CacheConfig objects are equal")
-        .def("__ne__", &CacheConfig::operator!=, py::is_operator(),
-             "Check if two CacheConfig objects are not equal")
-        .def("__repr__", [](const CacheConfig &cfg) {
-            return fmt::format("CacheConfig(type={}, num_layers={}, max_kv_cache_length={}, "
-                               "initial_capacity={}, initial_batch_size={}, growth_factor={}, "
-                               "allow_expand={}, reset_mode={})",
-                               static_cast<int>(cfg.type), cfg.num_layers,
-                               cfg.max_kv_cache_length, cfg.initial_capacity,
-                               cfg.initial_batch_size, cfg.growth_factor,
-                               cfg.allow_expand, static_cast<int>(cfg.reset_mode));
-        });
-}
-} // namespace infinilm::cache
 namespace infinilm::engine::distributed {
 inline void bind_dist_config(py::module &m) {
@@ -86,19 +31,24 @@ namespace infinilm::engine {
 inline void bind_infer_engine(py::module &m) {
     py::class_<InferEngine, std::shared_ptr<InferEngine>> infer_engine(m, "InferEngine");
     infer_engine
-        .def(py::init([](const InfinilmModel::Config &cfg,
-                         const infinilm::engine::distributed::DistConfig &dist,
-                         infinicore::Device::Type dev,
-                         const infinilm::cache::CacheConfig &cache_config) {
-                 return new InferEngine(cfg, dist, dev, cache_config);
+        .def(py::init([](
+                          const InfinilmModel::Config &cfg,
+                          const distributed::DistConfig &dist,
+                          infinicore::Device::Type dev,
+                          std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg) {
+                 return std::make_shared<InferEngine>(
+                     cfg,
+                     dist,
+                     dev,
+                     cache_cfg ? cache_cfg.get() : nullptr);
              }),
             py::arg("config"),
            py::arg("distributed_config") = distributed::DistConfig(),
            py::arg("device_type") = infinicore::context::getDevice().getType(),
-            py::arg("cache_config") = infinilm::cache::CacheConfig())
+            py::arg("cache_config") = py::none())
.def("load_param", &InferEngine::load_param, .def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"), py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)") "Load a parameter tensor into all workers (each worker picks its shard)")
.def("state_dict", [](InferEngine &self) { .def("state_dict", [](InferEngine &self) {
py::list state_dict_tp_all; py::list state_dict_tp_all;
for (const auto &state_dict_tp : self.state_dict()) { for (const auto &state_dict_tp : self.state_dict()) {
...@@ -110,19 +60,26 @@ inline void bind_infer_engine(py::module &m) { ...@@ -110,19 +60,26 @@ inline void bind_infer_engine(py::module &m) {
} }
return state_dict_tp_all; return state_dict_tp_all;
}) })
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments") .def(
.def("reset_cache", py::overload_cast<size_t>(&InferEngine::reset_cache), py::arg("pos") = 0, "Reset the internal cache in all workers to a specific position") "forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def("reset_cache", py::overload_cast<const cache::CacheConfig &, size_t>(&InferEngine::reset_cache), py::arg("cache_config"), py::arg("pos") = 0, "Reset cache with new KV configuration") .def(
.def("get_cache_config", &InferEngine::get_cache_config, "Get current KV configuration") "reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) {
self.reset_cache(cfg ? cfg.get() : nullptr);
},
py::arg("cache_config") = py::none())
.def("get_cache_config", [](const InferEngine &self) {
auto cfg = self.get_cache_config();
return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy()));
})
.def("__repr__", [](const InferEngine &self) { .def("__repr__", [](const InferEngine &self) {
return "<InferEngine: " + std::string(self.get_dist_config()) + ">"; return "<InferEngine: " + std::string(self.get_dist_config()) + ">";
}); });
py::class_<InferEngine::Input>(infer_engine, "Input") py::class_<InferEngine::Input>(infer_engine, "Input")
.def(py::init([](const infinicore::Tensor &input_ids, const infinicore::Tensor &position_ids) { .def(py::init([](const infinicore::Tensor &input_ids, const infinicore::Tensor &position_ids, const infinicore::Tensor &cache_positions) {
return new InferEngine::Input{input_ids, position_ids}; return new InferEngine::Input{input_ids, position_ids, cache_positions};
}), }),
py::arg("input_ids"), py::arg("position_ids")); py::arg("input_ids"), py::arg("position_ids"), py::arg("cache_positions"));
py::class_<InferEngine::Output>(infer_engine, "Output") py::class_<InferEngine::Output>(infer_engine, "Output")
.def_readwrite("logits", &InferEngine::Output::logits, "Output tensor"); .def_readwrite("logits", &InferEngine::Output::logits, "Output tensor");
...
@@ -3,22 +3,23 @@
 #include <cstring>
 #include <iostream>
+#include <spdlog/spdlog.h>
 #include <stdio.h>
 #include <stdlib.h>
-inline void assertTrue(int expr, const char *msg, const char *file, int line) {
+inline void assertTrue(int expr, const char *msg, const char *function, const char *file, int line) {
     if (!expr) {
-        fprintf(stderr, "\033[31mAssertion failed:\033[0m %s at file %s, line %d\n", msg, file, line);
-        exit(EXIT_FAILURE);
+        spdlog::error("Assertion failed: {} in function {} at file {}, line {}", msg, function, file, line);
+        throw std::runtime_error("Assertion failed");
     }
 }
-#define ASSERT(expr) assertTrue((expr), #expr " is false", __FILE__, __LINE__)
-#define ASSERT_EQ(a, b) assertTrue((a) == (b), #a " != " #b, __FILE__, __LINE__)
-#define ASSERT_VALID_PTR(a) assertTrue((a) != nullptr, #a " is nullptr", __FILE__, __LINE__)
+#define ASSERT(expr) assertTrue((expr), #expr " is false", __func__, __FILE__, __LINE__)
+#define ASSERT_EQ(a, b) assertTrue((a) == (b), #a " != " #b, __func__, __FILE__, __LINE__)
+#define ASSERT_VALID_PTR(a) assertTrue((a) != nullptr, #a " is nullptr", __func__, __FILE__, __LINE__)
 #define PANIC(EXPR) \
-    printf("Error at %s:%d - %s\n", __FILE__, __LINE__, #EXPR); \
+    spdlog::error("Error at {} in function {} at file {}, line {}", #EXPR, __func__, __FILE__, __LINE__); \
     exit(EXIT_FAILURE)
 #define RUN_INFINI(API) \
@@ -28,7 +29,7 @@ inline void assertTrue(int expr, const char *msg, const char *file, int line) {
             std::cerr << "Error Code " << api_result_ << " in `" << #API << "`" \
                       << " from " << __func__ \
                       << " at " << __FILE__ << ":" << __LINE__ << std::endl; \
-            exit(EXIT_FAILURE); \
+            throw std::runtime_error("InfiniCore C API Error"); \
         } \
     } while (0)
...
@@ -317,7 +317,7 @@ if __name__ == "__main__":
         # reset cache for each case
         initial_capacity = input_len + output_len + 100
         test.model.reset_cache(
-            batch_size=batch_size, pos=0, initial_capacity=initial_capacity
+            batch_size=batch_size, initial_capacity=initial_capacity
         )
         # run test one case
...
@@ -141,14 +141,21 @@ def test(
         )
         for prompt in prompts
     ]
-    print(input_contents[0], end="", flush=True)
     input_ids_list = tokenizer.batch_encode_plus(input_contents)[
         "input_ids"
     ]  # List: [[1, 1128, 526, 366, 29892]]
+    # Create the KV cache based on the input length and the maximum output length
+    model.reset_cache(
+        1 if prompts is str else len(prompts),
+        max_new_tokens + len(input_ids_list[0]),
+    )
     # ---------------------------------------------------------------------------- #
     # Autoregressive generation
     # ---------------------------------------------------------------------------- #
+    print(input_contents[0], end="", flush=True)
     input_ids_infini = infinicore.from_list(input_ids_list)
     t1 = time.time()
...
 from .models import AutoLlamaModel
 from . import distributed
+from . import cache
-__all__ = ["AutoLlamaModel", "distributed"]
+__all__ = ["AutoLlamaModel", "distributed", "cache"]
...
+from .cache import CacheConfig, StaticKVCacheConfig
+__all__ = ["CacheConfig", "StaticKVCacheConfig"]
+from infinilm.lib import _infinilm
+class CacheConfig(_infinilm.CacheConfig):
+    def __init__(self):
+        raise NotImplementedError(
+            "CacheConfig is an abstract class. Please use a subclass of CacheConfig."
+        )
+class StaticKVCacheConfig(CacheConfig, _infinilm.StaticKVCacheConfig):
+    def __init__(self, max_batch_size: int = 1, max_cache_len: int = 4096):
+        _infinilm.StaticKVCacheConfig.__init__(self, max_batch_size, max_cache_len)
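The wrapper makes CacheConfig abstract on the Python side: only subclasses can be instantiated. A short usage sketch:

    from infinilm.cache import CacheConfig, StaticKVCacheConfig

    cfg = StaticKVCacheConfig(max_batch_size=1, max_cache_len=4096)
    assert isinstance(cfg, CacheConfig)

    try:
        CacheConfig()  # abstract: direct construction raises
    except NotImplementedError as e:
        print(e)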
@@ -69,16 +69,18 @@ class GenerationMixin:
             model_inputs["past_key_values"] = past_key_values
         # -------------------------------------------------------------------------- #
         # Compute the required position_ids
         # -------------------------------------------------------------------------- #
         current_position_ids = kwargs.get("position_ids", None)
         if current_position_ids is None:
             # prefill phase
             bs, seq_len = kwargs["input_ids"].shape[0:2]
             model_inputs["position_ids"] = self._get_initial_position_ids(bs, seq_len)
+            model_inputs["cache_positions"] = infinicore.from_list(
+                [0], dtype=infinicore.int64
+            )
         else:
             # decode phase
             bs, seq_len = current_position_ids.shape
             last_position = current_position_ids.narrow(1, seq_len - 1, 1)
@@ -90,7 +92,13 @@ class GenerationMixin:
             next_position = one_value + last_position
             model_inputs["position_ids"] = next_position
+            model_inputs["cache_positions"] = kwargs[
+                "cache_positions"
+            ] + infinicore.from_list(
+                [seq_len],
+                dtype=last_position.dtype,
+                device=last_position.device,
+            )
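Taken together, the two branches advance cache_positions as follows: prefill starts it at [0]; the first decode step adds the prompt length (the previous position_ids spanned the whole prompt), and every later step adds 1. Illustrated with plain ints (the real code uses infinicore tensors):

    prompt_len = 5
    cache_positions = 0            # prefill writes the KV cache from offset 0
    cache_positions += prompt_len  # first decode step: previous seq_len == prompt_len -> 5
    cache_positions += 1           # each later decode step: previous seq_len == 1 -> 6, 7, ...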
         # -------------------------------------------------------------------- #
         # Required: the token input_ids
         # -------------------------------------------------------------------- #
@@ -119,29 +127,14 @@ class GenerationMixin:
     ):
         model_kwargs = kwargs
-        # -------------------------------------------------------------------- #
-        # CRITICAL: Reset internal cache before each new generation
-        # This prevents state from persisting between different questions/prompts
-        # -------------------------------------------------------------------- #
         # Check if this is a cpp backend model (has _model attribute with reset_cache method)
-        if hasattr(self, "_model") and hasattr(self._model, "reset_cache"):
-            try:
-                self._model.reset_cache()
-            except Exception as e:
-                # If reset_cache fails, log but continue (shouldn't happen)
-                import warnings
-                warnings.warn(f"Failed to reset cache: {e}")
-        # -------------------------------------------------------------------- #
-        # Create the cache #
-        # -------------------------------------------------------------------- #
-        if self.use_cache:
-            model_kwargs["use_cache"] = True
-            model_kwargs["past_key_values"] = DynamicCache(config=self.config)
-        else:
-            model_kwargs["use_cache"] = False
-            model_kwargs["past_key_values"] = None
+        if not (hasattr(self, "_model") and hasattr(self._model, "reset_cache")):
+            if self.use_cache:
+                model_kwargs["use_cache"] = True
+                model_kwargs["past_key_values"] = DynamicCache(config=self.config)
+            else:
+                model_kwargs["use_cache"] = False
+                model_kwargs["past_key_values"] = None
         # -------------------------------------------------------------------- #
         # The _sample function                                                 #
@@ -204,6 +197,7 @@ class GenerationMixin:
         model_inputs = self.prepare_inputs_for_generation(**model_kwargs)
         model_kwargs["position_ids"] = model_inputs["position_ids"]
+        model_kwargs["cache_positions"] = model_inputs["cache_positions"]
         # -------------------------------------------------------------------------- #
         # Compute once
...
@@ -2,11 +2,11 @@ from ....generation.utils import GenerationMixin
 import infinicore
 from infinilm.models.llama.configuration_llama import LlamaConfig
 from infinilm.lib import _infinilm
+from infinilm.cache import StaticKVCacheConfig
 from infinilm.distributed import DistConfig
 import json
 import os
 from typing import Optional, Union
-from collections import OrderedDict
 class LlamaForCausalLM(GenerationMixin):
@@ -18,6 +18,7 @@ class LlamaForCausalLM(GenerationMixin):
         device=None,
         dtype=None,
         distributed_config=DistConfig(1),
+        cache_config=None,
     ):
         """
         Create LlamaForCausalLM
@@ -51,18 +52,19 @@ class LlamaForCausalLM(GenerationMixin):
         #     config._underlying, device._underlying, dtype
         # )
         self._model = _infinilm.InferEngine(
-            config, distributed_config._underlying, device._underlying.type
+            config,
+            distributed_config._underlying,
+            device._underlying.type,
+            cache_config,
         )
-    def reset_cache(self, batch_size: int, pos: int = 0, initial_capacity: int = 1024):
+    def reset_cache(self, batch_size: int, initial_capacity: int = 1024):
         """Reset the cache for the model"""
         infinicore.sync_device()
-        cache_config = self._model.get_cache_config()
-        cache_config.initial_batch_size = batch_size
-        cache_config.initial_capacity = initial_capacity
-        self._model.reset_cache(cache_config, pos)
+        cache_config = StaticKVCacheConfig(batch_size, initial_capacity)
+        self._model.reset_cache(cache_config)
     def state_dict_keyname(self):
         """Get model key name."""
@@ -102,19 +104,25 @@ class LlamaForCausalLM(GenerationMixin):
         # like get_text_config() used by DynamicCache
         return self._config
-    def forward(self, input_ids, position_ids, *args, **kwargs):
-        kv_caches = None
-        # return infinicore.Tensor(
-        #     self._model.forward(input_ids, position_ids, kv_caches)
-        # )
+    def forward(self, input_ids, position_ids, cache_positions, *args, **kwargs):
         return infinicore.Tensor(
             self._model.forward(
-                self._model.Input(input_ids._underlying, position_ids._underlying)
+                self._model.Input(
+                    input_ids._underlying,
+                    position_ids._underlying,
+                    cache_positions._underlying,
+                )
             ).logits
         )
-    def __call__(self, input_ids, position_ids, *args, **kwargs):
-        return self.forward(input_ids=input_ids, position_ids=position_ids)
+    def __call__(self, input_ids, position_ids, cache_positions, *args, **kwargs):
+        return self.forward(
+            input_ids=input_ids,
+            position_ids=position_ids,
+            cache_positions=cache_positions,
+            *args,
+            **kwargs,
+        )
     @classmethod
     def from_pretrained(
...
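With cache_positions threaded through the whole stack, an end-to-end forward call now takes three tensors. A sketch, assuming `model` is a loaded LlamaForCausalLM and the token ids are illustrative:

    import infinicore

    input_ids = infinicore.from_list([[1, 1128, 526, 366, 29892]])
    position_ids = infinicore.from_list([[0, 1, 2, 3, 4]])
    cache_positions = infinicore.from_list([0], dtype=infinicore.int64)

    logits = model(input_ids, position_ids, cache_positions)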