Unverified commit d0239867 authored by PanZezhong1725, committed by GitHub

Merge pull request #156 from InfiniTensor/issue/125_pzz

issue/125 Unify the Cache interface
parents 13a4154a ff00b5c8
......@@ -7,6 +7,9 @@
namespace infinilm {
class InfinilmModelFactory {
public:
static std::shared_ptr<InfinilmModel> createModel(const InfinilmModel::Config &config, engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(), std::shared_ptr<cache::DynamicCache> cache_ptr = nullptr);
static std::shared_ptr<InfinilmModel> createModel(
const InfinilmModel::Config &config,
engine::distributed::RankInfo rank_info = engine::distributed::RankInfo(),
const cache::CacheConfig *cache = nullptr);
};
} // namespace infinilm
#include <pybind11/pybind11.h>
#include "cache/cache.hpp"
#include "engine/engine.hpp"
#include "models/llama.hpp"
#include "engine.hpp"
namespace py = pybind11;
PYBIND11_MODULE(_infinilm, m) {
m.doc() = "InfiniLM Llama model Python bindings";
infinilm::cache::bind_cache_config(m);
infinilm::cache::bind_cache(m);
infinilm::models::llama::bind_llama(m);
infinilm::engine::distributed::bind_dist_config(m);
infinilm::engine::bind_infer_engine(m);
......
#include "../../cache/cache.hpp"
#include "infinicore/tensor.hpp"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <limits> // std::numeric_limits used in the max_cache_len default
namespace py = pybind11;
namespace infinilm::cache {
inline void bind_cache(py::module &m) {
py::class_<infinilm::cache::CacheConfig,
std::shared_ptr<infinilm::cache::CacheConfig>>(m, "CacheConfig")
.def("__repr__", [](const infinilm::cache::CacheConfig &) {
return "<CacheConfig (abstract)>";
});
py::class_<infinilm::cache::StaticKVCacheConfig,
infinilm::cache::CacheConfig,
std::shared_ptr<infinilm::cache::StaticKVCacheConfig>>(m, "StaticKVCacheConfig")
.def(
py::init<infinicore::Size, infinicore::Size>(),
py::arg("max_batch_size") = 1,
py::arg("max_cache_len") = std::numeric_limits<infinicore::Size>::max())
.def(
"max_batch_size",
&infinilm::cache::StaticKVCacheConfig::max_batch_size)
.def(
"max_cache_len",
&infinilm::cache::StaticKVCacheConfig::max_cache_len)
.def("__repr__", [](const infinilm::cache::StaticKVCacheConfig &) {
return "<StaticKVCacheConfig>";
});
}
} // namespace infinilm::cache
\ No newline at end of file
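As a quick illustration of the new cache bindings, a minimal Python sketch, assuming the `_infinilm` extension has been built (the constructor defaults mirror the `py::arg` values above; note that `max_batch_size` and `max_cache_len` are bound as methods, not properties):

```python
from infinilm.lib import _infinilm

# Defaults bound above: max_batch_size=1, max_cache_len=<Size max>.
cfg = _infinilm.StaticKVCacheConfig(max_batch_size=2, max_cache_len=4096)
print(cfg)                    # <StaticKVCacheConfig>
print(cfg.max_batch_size())   # 2  (bound as a method, not a property)
print(cfg.max_cache_len())    # 4096
assert isinstance(cfg, _infinilm.CacheConfig)  # inherits the abstract base
```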
#include "../cache/cache_config.hpp"
#include "../engine/infer_engine.hpp"
#include "../../engine/infer_engine.hpp"
#include "infinicore/tensor.hpp"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
namespace py = pybind11;
namespace infinilm::cache {
inline void bind_cache_config(py::module &m) {
// First bind the CacheType enum
py::enum_<CacheType>(m, "CacheType")
.value("DYNAMIC", CacheType::DYNAMIC)
.value("PAGED", CacheType::PAGED)
.export_values();
// Then bind the CacheResetMode enum
py::enum_<CacheResetMode>(m, "CacheResetMode")
.value("PRESERVE", CacheResetMode::PRESERVE)
.value("RECREATE", CacheResetMode::RECREATE)
.export_values();
// Finally bind the CacheConfig struct
py::class_<CacheConfig>(m, "CacheConfig")
.def(py::init<>(), "Default constructor")
.def(py::init<CacheType, size_t, size_t>(),
py::arg("type") = CacheType::DYNAMIC,
py::arg("num_layers") = 32,
py::arg("max_kv_cache_length") = 4096,
"Constructor with parameters")
.def_readwrite("type", &CacheConfig::type, "Cache type")
.def_readwrite("num_layers", &CacheConfig::num_layers, "Number of layers")
.def_readwrite("max_kv_cache_length", &CacheConfig::max_kv_cache_length,
"Maximum KV cache length")
.def_readwrite("initial_capacity", &CacheConfig::initial_capacity,
"Initial cache capacity in tokens")
.def_readwrite("initial_batch_size", &CacheConfig::initial_batch_size,
"Initial batch size for cache allocation")
.def_readwrite("growth_factor", &CacheConfig::growth_factor,
"Cache growth factor when resizing (e.g., 2.0 for doubling)")
.def_readwrite("allow_expand", &CacheConfig::allow_expand,
"Whether to allow cache expansion")
.def_readwrite("reset_mode", &CacheConfig::reset_mode,
"Cache reset mode")
.def("__eq__", &CacheConfig::operator==, py::is_operator(),
"Check if two CacheConfig objects are equal")
.def("__ne__", &CacheConfig::operator!=, py::is_operator(),
"Check if two CacheConfig objects are not equal")
.def("__repr__", [](const CacheConfig &cfg) {
return fmt::format("CacheConfig(type={}, num_layers={}, max_kv_cache_length={}, "
"initial_capacity={}, initial_batch_size={}, growth_factor={}, "
"allow_expand={}, reset_mode={})",
static_cast<int>(cfg.type), cfg.num_layers,
cfg.max_kv_cache_length, cfg.initial_capacity,
cfg.initial_batch_size, cfg.growth_factor,
cfg.allow_expand, static_cast<int>(cfg.reset_mode));
});
}
} // namespace infinilm::cache
namespace infinilm::engine::distributed {
inline void bind_dist_config(py::module &m) {
......@@ -86,19 +31,24 @@ namespace infinilm::engine {
inline void bind_infer_engine(py::module &m) {
py::class_<InferEngine, std::shared_ptr<InferEngine>> infer_engine(m, "InferEngine");
infer_engine
.def(py::init([](const InfinilmModel::Config &cfg,
const infinilm::engine::distributed::DistConfig &dist,
infinicore::Device::Type dev,
const infinilm::cache::CacheConfig &cache_config) {
return new InferEngine(cfg, dist, dev, cache_config);
.def(py::init([](
const InfinilmModel::Config &cfg,
const distributed::DistConfig &dist,
infinicore::Device::Type dev,
std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg) {
return std::make_shared<InferEngine>(
cfg,
dist,
dev,
cache_cfg ? cache_cfg.get() : nullptr);
}),
py::arg("config"),
py::arg("distributed_config") = distributed::DistConfig(),
py::arg("device_type") = infinicore::context::getDevice().getType(),
py::arg("cache_config") = infinilm::cache::CacheConfig())
.def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)")
py::arg("cache_config") = py::none())
.def("load_param", &InferEngine::load_param,
py::arg("name"), py::arg("param"),
"Load a parameter tensor into all workers (each worker picks its shard)")
.def("state_dict", [](InferEngine &self) {
py::list state_dict_tp_all;
for (const auto &state_dict_tp : self.state_dict()) {
......@@ -110,19 +60,26 @@ inline void bind_infer_engine(py::module &m) {
}
return state_dict_tp_all;
})
.def("forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def("reset_cache", py::overload_cast<size_t>(&InferEngine::reset_cache), py::arg("pos") = 0, "Reset the internal cache in all workers to a specific position")
.def("reset_cache", py::overload_cast<const cache::CacheConfig &, size_t>(&InferEngine::reset_cache), py::arg("cache_config"), py::arg("pos") = 0, "Reset cache with new KV configuration")
.def("get_cache_config", &InferEngine::get_cache_config, "Get current KV configuration")
.def(
"forward", [](InferEngine &self, const InferEngine::Input &input) -> InferEngine::Output { return self.forward(input); }, "Run inference on all ranks with arbitrary arguments")
.def(
"reset_cache", [](InferEngine &self, std::shared_ptr<const cache::CacheConfig> cfg) {
self.reset_cache(cfg ? cfg.get() : nullptr);
},
py::arg("cache_config") = py::none())
.def("get_cache_config", [](const InferEngine &self) {
auto cfg = self.get_cache_config();
return std::shared_ptr<cache::CacheConfig>(std::move(cfg->unique_copy()));
})
.def("__repr__", [](const InferEngine &self) {
return "<InferEngine: " + std::string(self.get_dist_config()) + ">";
});
py::class_<InferEngine::Input>(infer_engine, "Input")
.def(py::init([](const infinicore::Tensor &input_ids, const infinicore::Tensor &position_ids) {
return new InferEngine::Input{input_ids, position_ids};
.def(py::init([](const infinicore::Tensor &input_ids, const infinicore::Tensor &position_ids, const infinicore::Tensor &cache_positions) {
return new InferEngine::Input{input_ids, position_ids, cache_positions};
}),
py::arg("input_ids"), py::arg("position_ids"));
py::arg("input_ids"), py::arg("position_ids"), py::arg("cache_positions"));
py::class_<InferEngine::Output>(infer_engine, "Output")
.def_readwrite("logits", &InferEngine::Output::logits, "Output tensor");
......
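A hedged sketch of how the rebound engine interface might be driven from Python (`config` and the three tensors are placeholders; what this diff establishes is that `cache_config` defaults to `None` and that `Input` now carries `cache_positions`):

```python
from infinilm.lib import _infinilm

# `config` is a placeholder InfinilmModel.Config; distributed_config and
# device_type fall back to their bound defaults.
engine = _infinilm.InferEngine(
    config,
    cache_config=_infinilm.StaticKVCacheConfig(1, 4096),  # or omit for None
)

# input_ids, position_ids, cache_positions: infinicore tensors (placeholders).
inp = _infinilm.InferEngine.Input(input_ids, position_ids, cache_positions)
logits = engine.forward(inp).logits

engine.reset_cache()  # cache_config=None keeps the current configuration
engine.reset_cache(cache_config=_infinilm.StaticKVCacheConfig(2, 8192))
print(engine.get_cache_config())  # returns a fresh copy of the active config
```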
......@@ -3,22 +3,23 @@
#include <cstring>
#include <iostream>
#include <spdlog/spdlog.h>
#include <stdexcept> // std::runtime_error thrown below
#include <stdio.h>
#include <stdlib.h>
inline void assertTrue(int expr, const char *msg, const char *file, int line) {
inline void assertTrue(int expr, const char *msg, const char *function, const char *file, int line) {
if (!expr) {
fprintf(stderr, "\033[31mAssertion failed:\033[0m %s at file %s, line %d\n", msg, file, line);
exit(EXIT_FAILURE);
spdlog::error("Assertion failed: {} in function {} at file {}, line {}", msg, function, file, line);
throw std::runtime_error("Assertion failed");
}
}
#define ASSERT(expr) assertTrue((expr), #expr " is false", __FILE__, __LINE__)
#define ASSERT_EQ(a, b) assertTrue((a) == (b), #a " != " #b, __FILE__, __LINE__)
#define ASSERT_VALID_PTR(a) assertTrue((a) != nullptr, #a " is nullptr", __FILE__, __LINE__)
#define ASSERT(expr) assertTrue((expr), #expr " is false", __func__, __FILE__, __LINE__)
#define ASSERT_EQ(a, b) assertTrue((a) == (b), #a " != " #b, __func__, __FILE__, __LINE__)
#define ASSERT_VALID_PTR(a) assertTrue((a) != nullptr, #a " is nullptr", __func__, __FILE__, __LINE__)
#define PANIC(EXPR) \
printf("Error at %s:%d - %s\n", __FILE__, __LINE__, #EXPR); \
#define PANIC(EXPR) \
spdlog::error("Error at {} in function {} at file {}, line {}", #EXPR, function, file, line); \
exit(EXIT_FAILURE)
#define RUN_INFINI(API) \
......@@ -28,7 +29,7 @@ inline void assertTrue(int expr, const char *msg, const char *file, int line) {
std::cerr << "Error Code " << api_result_ << " in `" << #API << "`" \
<< " from " << __func__ \
<< " at " << __FILE__ << ":" << __LINE__ << std::endl; \
exit(EXIT_FAILURE); \
throw std::runtime_error("InfiniCore C API Error"); \
} \
} while (0)
......
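Since the assertion and `RUN_INFINI` failure paths now throw `std::runtime_error` instead of calling `exit()`, pybind11's default exception translation surfaces them as Python `RuntimeError`s, so callers can recover instead of losing the process; a small sketch (the failing call is hypothetical):

```python
try:
    engine.forward(bad_input)  # hypothetical call that trips an ASSERT/RUN_INFINI
except RuntimeError as err:
    # Previously this would have been exit(EXIT_FAILURE); now it is catchable.
    print(f"engine call failed, process still alive: {err}")
```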
......@@ -317,7 +317,7 @@ if __name__ == "__main__":
# reset cache for each case
initial_capacity = input_len + output_len + 100
test.model.reset_cache(
batch_size=batch_size, pos=0, initial_capacity=initial_capacity
batch_size=batch_size, initial_capacity=initial_capacity
)
# run test one case
......
......@@ -141,14 +141,21 @@ def test(
)
for prompt in prompts
]
print(input_contents[0], end="", flush=True)
input_ids_list = tokenizer.batch_encode_plus(input_contents)[
"input_ids"
] # List: [[1, 1128, 526, 366, 29892]]
# Create the KV cache based on the input length and the maximum output length
model.reset_cache(
1 if isinstance(prompts, str) else len(prompts),
max_new_tokens + len(input_ids_list[0]),
)
# ---------------------------------------------------------------------------- #
# Autoregressive generation
# ---------------------------------------------------------------------------- #
print(input_contents[0], end="", flush=True)
input_ids_infini = infinicore.from_list(input_ids_list)
t1 = time.time()
......
from .models import AutoLlamaModel
from . import distributed
from . import cache
__all__ = ["AutoLlamaModel", "distributed"]
__all__ = ["AutoLlamaModel", "distributed", "cache"]
from .cache import CacheConfig, StaticKVCacheConfig
__all__ = ["CacheConfig", "StaticKVCacheConfig"]
from infinilm.lib import _infinilm
class CacheConfig(_infinilm.CacheConfig):
def __init__(self):
raise NotImplementedError(
"CacheConfig is an abstract class. Please use a subclass of CacheConfig."
)
class StaticKVCacheConfig(CacheConfig, _infinilm.StaticKVCacheConfig):
def __init__(self, max_batch_size: int = 1, max_cache_len: int = 4096):
_infinilm.StaticKVCacheConfig.__init__(self, max_batch_size, max_cache_len)
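The wrapper keeps `CacheConfig` abstract on the Python side while `StaticKVCacheConfig` stays constructible; a short sketch of the intended usage:

```python
from infinilm.cache import CacheConfig, StaticKVCacheConfig

cfg = StaticKVCacheConfig(max_batch_size=1, max_cache_len=4096)
assert isinstance(cfg, CacheConfig)  # subclass of the abstract base

try:
    CacheConfig()  # direct construction is rejected by design
except NotImplementedError as err:
    print(err)
```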
......@@ -69,16 +69,18 @@ class GenerationMixin:
model_inputs["past_key_values"] = past_key_values
# -------------------------------------------------------------------------- #
# Compute the required position_ids
# Compute the required: position_ids
# -------------------------------------------------------------------------- #
current_position_ids = kwargs.get("position_ids", None)
if current_position_ids is None:
# prefill stage
bs, seq_len = kwargs["input_ids"].shape[0:2]
model_inputs["position_ids"] = self._get_initial_position_ids(bs, seq_len)
model_inputs["cache_positions"] = infinicore.from_list(
[0], dtype=infinicore.int64
)
else:
# decoder stage
# decode stage
bs, seq_len = current_position_ids.shape
last_position = current_position_ids.narrow(1, seq_len - 1, 1)
......@@ -90,7 +92,13 @@ class GenerationMixin:
next_position = one_value + last_position
model_inputs["position_ids"] = next_position
model_inputs["cache_positions"] = kwargs[
"cache_positions"
] + infinicore.from_list(
[seq_len],
dtype=last_position.dtype,
device=last_position.device,
)
# -------------------------------------------------------------------- #
# Required: the token's input_ids
# -------------------------------------------------------------------- #
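The `cache_positions` bookkeeping above reduces to simple arithmetic; a standalone sketch with plain ints in place of `infinicore` tensors (the step lengths are illustrative):

```python
def advance_cache_position(prev_pos: int, prev_seq_len: int) -> int:
    # Mirrors `cache_positions + [seq_len]` above: the next step writes right
    # after the prev_seq_len tokens consumed by the previous step.
    return prev_pos + prev_seq_len

pos = 0                               # prefill writes the prompt at slot 0
pos = advance_cache_position(pos, 5)  # first decode step after a 5-token prompt -> 5
pos = advance_cache_position(pos, 1)  # each later decode step advances by 1 -> 6
```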
......@@ -119,29 +127,14 @@ class GenerationMixin:
):
model_kwargs = kwargs
# -------------------------------------------------------------------- #
# CRITICAL: Reset internal cache before each new generation
# This prevents state from persisting between different questions/prompts
# -------------------------------------------------------------------- #
# Check if this is a cpp backend model (has _model attribute with reset_cache method)
if hasattr(self, "_model") and hasattr(self._model, "reset_cache"):
try:
self._model.reset_cache()
except Exception as e:
# If reset_cache fails, log but continue (shouldn't happen)
import warnings
warnings.warn(f"Failed to reset cache: {e}")
# -------------------------------------------------------------------- #
# Create the cache                                                   #
# -------------------------------------------------------------------- #
if self.use_cache:
model_kwargs["use_cache"] = True
model_kwargs["past_key_values"] = DynamicCache(config=self.config)
else:
model_kwargs["use_cache"] = False
model_kwargs["past_key_values"] = None
if not (hasattr(self, "_model") and hasattr(self._model, "reset_cache")):
if self.use_cache:
model_kwargs["use_cache"] = True
model_kwargs["past_key_values"] = DynamicCache(config=self.config)
else:
model_kwargs["use_cache"] = False
model_kwargs["past_key_values"] = None
# -------------------------------------------------------------------- #
# The _sample function                                               #
......@@ -204,6 +197,7 @@ class GenerationMixin:
model_inputs = self.prepare_inputs_for_generation(**model_kwargs)
model_kwargs["position_ids"] = model_inputs["position_ids"]
model_kwargs["cache_positions"] = model_inputs["cache_positions"]
# -------------------------------------------------------------------------- #
# Run one forward pass
......
......@@ -2,11 +2,11 @@ from ....generation.utils import GenerationMixin
import infinicore
from infinilm.models.llama.configuration_llama import LlamaConfig
from infinilm.lib import _infinilm
from infinilm.cache import StaticKVCacheConfig
from infinilm.distributed import DistConfig
import json
import os
from typing import Optional, Union
from collections import OrderedDict
class LlamaForCausalLM(GenerationMixin):
......@@ -18,6 +18,7 @@ class LlamaForCausalLM(GenerationMixin):
device=None,
dtype=None,
distributed_config=DistConfig(1),
cache_config=None,
):
"""
Create LlamaForCausalLM
......@@ -51,18 +52,19 @@ class LlamaForCausalLM(GenerationMixin):
# config._underlying, device._underlying, dtype
# )
self._model = _infinilm.InferEngine(
config, distributed_config._underlying, device._underlying.type
config,
distributed_config._underlying,
device._underlying.type,
cache_config,
)
def reset_cache(self, batch_size: int, pos: int = 0, initial_capacity: int = 1024):
def reset_cache(self, batch_size: int, initial_capacity: int = 1024):
"""Reset the cache for the model"""
infinicore.sync_device()
cache_config = self._model.get_cache_config()
cache_config.initial_batch_size = batch_size
cache_config.initial_capacity = initial_capacity
cache_config = StaticKVCacheConfig(batch_size, initial_capacity)
self._model.reset_cache(cache_config, pos)
self._model.reset_cache(cache_config)
def state_dict_keyname(self):
"""Get model key name."""
......@@ -102,19 +104,25 @@ class LlamaForCausalLM(GenerationMixin):
# like get_text_config() used by DynamicCache
return self._config
def forward(self, input_ids, position_ids, *args, **kwargs):
kv_caches = None
# return infinicore.Tensor(
# self._model.forward(input_ids, position_ids, kv_caches)
# )
def forward(self, input_ids, position_ids, cache_positions, *args, **kwargs):
return infinicore.Tensor(
self._model.forward(
self._model.Input(input_ids._underlying, position_ids._underlying)
self._model.Input(
input_ids._underlying,
position_ids._underlying,
cache_positions._underlying,
)
).logits
)
def __call__(self, input_ids, position_ids, *args, **kwargs):
return self.forward(input_ids=input_ids, position_ids=position_ids)
def __call__(self, input_ids, position_ids, cache_positions, *args, **kwargs):
return self.forward(
input_ids=input_ids,
position_ids=position_ids,
cache_positions=cache_positions,
*args,
**kwargs,
)
@classmethod
def from_pretrained(
......
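Tying the Python-side changes together, a hedged end-to-end sketch (the checkpoint path and token ids are hypothetical, and it assumes `AutoLlamaModel.from_pretrained` mirrors `LlamaForCausalLM.from_pretrained` above; the `reset_cache` signature, the three-tensor call, and the `int64` cache position follow this diff):

```python
import infinicore
from infinilm.models import AutoLlamaModel  # re-exported in __init__ above

model = AutoLlamaModel.from_pretrained("/path/to/llama")  # hypothetical path
model.reset_cache(batch_size=1, initial_capacity=4096)    # builds a StaticKVCacheConfig

input_ids = infinicore.from_list([[1, 1128, 526, 366, 29892]])
position_ids = infinicore.from_list([[0, 1, 2, 3, 4]])
cache_positions = infinicore.from_list([0], dtype=infinicore.int64)

logits = model(input_ids, position_ids, cache_positions)
```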