Commit 21274f33 authored by PanZezhong, committed by wooway777

issue/143 feat: static and paged graph compilers

parent 96ecf490
@@ -29,3 +29,5 @@ __pycache__/
 *.txt
+*.http
+*.nsys-rep

#include "general_compiler.hpp"
namespace infinilm::engine {
GeneralCompiler::GeneralCompiler(const std::shared_ptr<InfinilmModel> &model) : GraphCompiler(model) {
static_batching_compiler_ = std::make_unique<StaticBatchingCompiler>(model_);
paged_compiler_ = std::make_unique<PagedCompiler>(model_);
}
void GeneralCompiler::compile() {
static_batching_compiler_->compile();
paged_compiler_->compile();
}
GeneralCompiler::Compiled GeneralCompiler::get_compiled(const InfinilmModel::Input &input) {
GeneralCompiler::Compiled result = {nullptr, nullptr};
// try each compiler, return the first valid result
result = static_batching_compiler_.get()->get_compiled(input);
if (std::get<0>(result) != nullptr && std::get<1>(result) != nullptr) {
return result;
}
result = paged_compiler_.get()->get_compiled(input);
return result;
}
} // namespace infinilm::engine

#pragma once

#include "paged_compiler.hpp"
#include "static_batching_compiler.hpp"

namespace infinilm::engine {

class GeneralCompiler : public GraphCompiler {
public:
    GeneralCompiler(const std::shared_ptr<InfinilmModel> &model);
    void compile() override;
    Compiled get_compiled(const InfinilmModel::Input &input) override;

private:
    std::unique_ptr<StaticBatchingCompiler> static_batching_compiler_;
    std::unique_ptr<PagedCompiler> paged_compiler_;
};

} // namespace infinilm::engine

#pragma once

#include "../../models/infinilm_model.hpp"

namespace infinilm::engine {

class GraphCompiler {
public:
    // A compiled entry: the recorded graph plus the output it writes into.
    using Compiled = std::tuple<
        std::shared_ptr<infinicore::graph::Graph>,
        std::shared_ptr<InfinilmModel::Output>>;

    explicit GraphCompiler(const std::shared_ptr<InfinilmModel> &model) : model_(model) {}
    virtual ~GraphCompiler() = default;

    virtual void compile() = 0;
    virtual Compiled get_compiled(const InfinilmModel::Input &input) = 0;

protected:
    std::shared_ptr<InfinilmModel> model_;
};

} // namespace infinilm::engine
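
Taken together, GraphCompiler fixes a two-phase contract: compile() records graphs ahead of time for a set of known input shapes, and get_compiled() either returns a ready-to-replay (graph, output) pair or {nullptr, nullptr} to tell the caller to fall back to eager execution. A minimal caller-side sketch of that contract (illustration only, mirroring the RankWorker change further down; `model` and `input` stand in for objects constructed elsewhere):

// Sketch, not part of this commit.
auto compiler = std::make_unique<GeneralCompiler>(model);
compiler->compile(); // record graphs up front for the supported shapes

// Per decode step:
auto [graph, output] = compiler->get_compiled(input);
if (graph != nullptr && output != nullptr) {
    graph->run();                 // replay the pre-recorded kernels
    auto logits = output->logits; // results land in the retained output tensor
} else {
    auto logits = model->forward(input).logits; // eager fallback
}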
#include "paged_compiler.hpp"
namespace infinilm::engine {
PagedCompiler::PagedCompiler(const std::shared_ptr<InfinilmModel> &model)
: GraphCompiler(model) {
for (size_t b = 1; b < 32; b++) {
decode_batch_sizes_.push_back(b);
}
for (size_t b = 32; b < 64; b += 8) {
decode_batch_sizes_.push_back(b);
}
for (size_t b = 64; b < 128; b += 16) {
decode_batch_sizes_.push_back(b);
}
for (size_t b = 128; b < 256; b += 32) {
decode_batch_sizes_.push_back(b);
}
for (size_t b = 256; b <= 512; b += 64) {
decode_batch_sizes_.push_back(b);
}
}
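
// Note (illustration, not part of this commit): these buckets are exact-match
// keys in get_compiled() below, so a live batch size must hit one of them
// precisely. A scheduler could pad a request up to the nearest bucket with a
// helper along these lines (needs <algorithm> and <optional>; relies on the
// ascending order of the list built above):
//
//     std::optional<size_t> nearest_bucket(const std::vector<size_t> &buckets,
//                                          size_t batch) {
//         auto it = std::lower_bound(buckets.begin(), buckets.end(), batch);
//         if (it == buckets.end()) {
//             return std::nullopt; // larger than the largest recorded graph
//         }
//         return *it;
//     }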

void PagedCompiler::compile() {
    const auto *paged_config = dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config());
    if (paged_config != nullptr) {
        size_t nblocks = paged_config->num_blocks();
        size_t max_batch_size = *std::max_element(decode_batch_sizes_.begin(), decode_batch_sizes_.end());
        compiled_map_decode_.clear();
        // One flat block-table buffer shared by every bucket; each bucket
        // views it as {b, nblocks / b} through as_strided below.
        block_tables_holder_ = infinicore::Tensor::empty(
            {nblocks}, infinicore::DataType::I64, infinicore::context::getDevice());
        for (size_t b : decode_batch_sizes_) {
            size_t block_per_req = nblocks / b;
            // Placeholder inputs with this bucket's shapes; the recorded
            // graph reads from these tensors on every replay.
            InfinilmModel::Input input;
            input.input_ids = infinicore::Tensor::empty({1, b}, infinicore::DataType::I64, infinicore::context::getDevice());
            input.position_ids = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
            input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
            input.input_offsets = infinicore::Tensor::empty({b + 1}, infinicore::DataType::I64, infinicore::context::getDevice());
            input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1});
            input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
            infinicore::context::startGraphRecording();
            auto output = model_->forward(input);
            auto graph = infinicore::context::stopGraphRecording();
            auto shared_output = std::shared_ptr<InfinilmModel::Output>(
                new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});
            compiled_map_decode_[b] = CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
        }
    }
}

PagedCompiler::Compiled PagedCompiler::get_compiled(const InfinilmModel::Input &input) {
    const auto *paged_config = dynamic_cast<const cache::PagedKVCacheConfig *>(model_->get_cache_config());
    if (paged_config == nullptr) {
        return {nullptr, nullptr};
    }
    size_t batch_size = input.block_tables.value()->size(0);
    size_t block_per_req = input.block_tables.value()->size(1);
    // Only decode-only batches are supported: each request contributes
    // exactly one token, so the token count must equal the batch size.
    if (batch_size != input.input_ids.value()->size(1)) {
        return {nullptr, nullptr};
    }
    auto result = compiled_map_decode_.find(batch_size);
    if (result == compiled_map_decode_.end()) {
        return {nullptr, nullptr};
    }
    // Copy the live inputs into the tensors captured at record time; the
    // replayed graph reads from those placeholders.
    auto &graph_input = result->second.input;
    graph_input.input_ids.value()->copy_from(input.input_ids.value());
    graph_input.position_ids.value()->copy_from(input.position_ids.value());
    graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
    graph_input.input_offsets.value()->copy_from(input.input_offsets.value());
    // Narrow the recorded block-table view to the live table's width first.
    graph_input.block_tables.value()->narrow({{1, 0, block_per_req}})->copy_from(input.block_tables.value());
    graph_input.slot_mapping.value()->copy_from(input.slot_mapping.value());
    auto graph = std::get<0>(result->second.compiled);
    auto shared_output = std::shared_ptr<InfinilmModel::Output>(
        new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()});
    return std::make_tuple(graph, shared_output);
}

} // namespace infinilm::engine
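
A detail worth calling out in the compiler above: all buckets share one flat block_tables_holder_ buffer, reinterpreted per bucket as a {b, nblocks / b} view via as_strided, so no per-bucket block table is allocated and every recorded graph reads from the same storage. A standalone illustration of the same trick, using only the tensor calls already seen above (the numbers are made up):

// Sketch, not part of this commit: 12 block slots viewed as 3 requests with
// 4 slots each; row i covers flat slots [4*i, 4*i + 4). No data is copied.
auto holder = infinicore::Tensor::empty(
    {12}, infinicore::DataType::I64, infinicore::context::getDevice());
auto view = holder->as_strided({3, 4}, {(ptrdiff_t)4, 1});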

#pragma once

#include "graph_compiler.hpp"

#include <unordered_map>
#include <vector>

namespace infinilm::engine {

class PagedCompiler : public GraphCompiler {
public:
    PagedCompiler(const std::shared_ptr<InfinilmModel> &model);
    void compile() override;
    Compiled get_compiled(const InfinilmModel::Input &input) override;

private:
    std::vector<size_t> decode_batch_sizes_;
    infinicore::Tensor block_tables_holder_;

    // A recorded graph plus the placeholder inputs it was recorded against.
    struct CompiledResult {
        InfinilmModel::Input input;
        Compiled compiled;
    };
    std::unordered_map<
        size_t, // num_requests
        CompiledResult>
        compiled_map_decode_;
};

} // namespace infinilm::engine
#include "static_batching_compiler.hpp"
#include "../../cache/cache.hpp"
namespace infinilm::engine {
StaticBatchingCompiler::StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model)
: GraphCompiler(model) {
}
void StaticBatchingCompiler::compile() {
if (model_->get_cache_config() != nullptr && dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config())) {
size_t b = dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config())->max_batch_size();
InfinilmModel::Input input;
input.input_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice());
input.position_ids = infinicore::Tensor::empty({b, 1}, infinicore::DataType::I64, infinicore::context::getDevice());
input.past_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
infinicore::context::startGraphRecording();
auto output = model_->forward(input);
auto graph = infinicore::context::stopGraphRecording();
auto shared_output = std::shared_ptr<InfinilmModel::Output>(new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});
compiled_map_[std::make_tuple(b, 1)] = CompiledResult{std::move(input), std::make_tuple(graph, shared_output)};
}
}

StaticBatchingCompiler::Compiled StaticBatchingCompiler::get_compiled(
    const InfinilmModel::Input &input) {
    const auto *static_config = dynamic_cast<const cache::StaticKVCacheConfig *>(model_->get_cache_config());
    if (static_config == nullptr) {
        return std::make_tuple(nullptr, nullptr);
    }
    size_t batch_size = input.input_ids.value()->size(0);
    size_t seqlen = input.input_ids.value()->size(1);
    auto result = compiled_map_.find(std::make_tuple(batch_size, seqlen));
    if (result == compiled_map_.end()) {
        return std::make_tuple(nullptr, nullptr);
    }
    // Copy live inputs into the recorded placeholders before replay.
    auto &graph_input = result->second.input;
    graph_input.input_ids.value()->copy_from(input.input_ids.value());
    graph_input.position_ids.value()->copy_from(input.position_ids.value());
    graph_input.past_sequence_lengths.value()->copy_from(input.past_sequence_lengths.value());
    graph_input.total_sequence_lengths.value()->copy_from(input.total_sequence_lengths.value());
    auto graph = std::get<0>(result->second.compiled);
    auto shared_output = std::shared_ptr<InfinilmModel::Output>(
        new InfinilmModel::Output{std::get<1>(result->second.compiled)->logits->resume_from_blob_()});
    return std::make_tuple(graph, shared_output);
}

} // namespace infinilm::engine

#pragma once

#include "graph_compiler.hpp"

#include <unordered_map>

namespace infinilm::engine {

class StaticBatchingCompiler : public GraphCompiler {
public:
    StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model);
    void compile() override;
    Compiled get_compiled(const InfinilmModel::Input &input) override;

private:
    // std::unordered_map has no default hash for std::tuple; combine the two
    // size_t hashes boost::hash_combine-style.
    struct TupleHash {
        size_t operator()(const std::tuple<size_t, size_t> &t) const noexcept {
            auto h1 = std::hash<size_t>{}(std::get<0>(t));
            auto h2 = std::hash<size_t>{}(std::get<1>(t));
            return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2));
        }
    };

    struct CompiledResult {
        InfinilmModel::Input input;
        Compiled compiled;
    };
    std::unordered_map<
        std::tuple<size_t, size_t>, // (batch_size, seq_len)
        CompiledResult,
        TupleHash>
        compiled_map_;
};

} // namespace infinilm::engine
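
The constant 0x9e3779b97f4a7c15 in TupleHash is 2^64 divided by the golden ratio, mixed with shifts so that equal halves do not simply cancel under XOR. A self-contained demo of the same functor outside the class (illustration only, not part of this commit):

// Standalone copy of the TupleHash functor; compiles on its own with C++17.
#include <cstddef>
#include <functional>
#include <tuple>
#include <unordered_map>

struct TupleHash {
    std::size_t operator()(const std::tuple<std::size_t, std::size_t> &t) const noexcept {
        auto h1 = std::hash<std::size_t>{}(std::get<0>(t));
        auto h2 = std::hash<std::size_t>{}(std::get<1>(t));
        return h1 ^ (h2 + 0x9e3779b97f4a7c15ULL + (h1 << 6) + (h1 >> 2));
    }
};

int main() {
    std::unordered_map<std::tuple<std::size_t, std::size_t>, int, TupleHash> shapes;
    shapes[{8, 1}] = 0; // keys are (batch_size, seq_len), as in compiled_map_
    return shapes.count({8, 1}) == 1 ? 0 : 1;
}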

@@ -10,7 +10,8 @@ InferEngine::InferEngine(
     const InfinilmModel::Config &config,
     const distributed::DistConfig &distributed_config,
     infinicore::Device::Type device_type,
-    const cache::CacheConfig *cache_config) // Changed parameter
+    const cache::CacheConfig *cache_config,
+    bool enable_graph_compiling) // Changed parameter
     : communication_group_(distributed_config, device_type),
       model_config_(config) {
@@ -24,7 +25,8 @@ InferEngine::InferEngine(
         workers_.emplace_back(std::make_unique<RankWorker>(
             model_config_,
             communication_group_.get_rank_info(r),
-            cache_config_ != nullptr ? cache_config_.get() : nullptr));
+            cache_config_ != nullptr ? cache_config_.get() : nullptr,
+            enable_graph_compiling));
     }
 }
......
@@ -22,7 +22,8 @@ public:
         const InfinilmModel::Config &config,
         const distributed::DistConfig &distributed_config = distributed::DistConfig(),
         infinicore::Device::Type device_type = infinicore::context::getDevice().getType(),
-        const cache::CacheConfig *cache_config = nullptr);
+        const cache::CacheConfig *cache_config = nullptr,
+        bool enable_graph_compiling = false);

     // Load a parameter to all workers (each can extract its shard inside RankWorker)
     void load_param(const std::string &name, const infinicore::Tensor &param);
......
......@@ -12,9 +12,11 @@ namespace infinilm::engine {
RankWorker::RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config)
const cache::CacheConfig *cache_config,
bool enable_graph_compiling)
: model_config_(model_config),
rank_info_(rank_info),
enable_graph_compiling_(enable_graph_compiling),
job_cmd_(Command::INIT),
has_job_(false),
job_done_(false),
@@ -180,6 +182,11 @@ void RankWorker::thread_loop() {
             if (!model_) {
                 throw std::runtime_error("Failed to create model");
             }
+            if (enable_graph_compiling_) {
+                compiler_ = std::make_unique<GeneralCompiler>(model_);
+                compiler_->compile();
+            }
+
             init_done_ = true;
         }
         cv_.notify_all();
@@ -245,9 +252,21 @@ void RankWorker::thread_loop() {
                 {
                     std::lock_guard<std::mutex> lk(mutex_);
-                    auto model_args = local_args.to_model_input(rank_info_.device);
-                    // Forward calculation
-                    auto logits{model_->forward(model_args).logits};
+                    infinicore::Tensor logits;
+                    // Try to get compiled graph
+                    if (compiler_ != nullptr) {
+                        auto [graph, output] = compiler_->get_compiled(local_args.to_model_input(infinicore::Device::cpu()));
+                        if (graph != nullptr && output != nullptr) {
+                            graph->run();
+                            logits = output->logits;
+                        }
+                    }
+                    // Fall back to eager mode
+                    if (!logits) {
+                        auto model_args = local_args.to_model_input(rank_info_.device);
+                        logits = model_->forward(model_args).logits;
+                    }
                     // Random sampling (rank 0 only)
                     if (rank_info_.tp_rank == 0) {
                         auto temperature{local_args.temperature};
@@ -296,6 +315,9 @@ void RankWorker::thread_loop() {
         } else if (local_cmd == Command::RESET_CACHE) {
             try {
                 model_->reset_cache(local_cache_config != nullptr ? local_cache_config.get() : nullptr);
+                if (compiler_ != nullptr) {
+                    compiler_->compile();
+                }
                 {
                     std::lock_guard<std::mutex> lk(mutex_);
......
@@ -2,6 +2,7 @@
 #include "../cache/cache.hpp"
 #include "../models/model_factory.hpp"
+#include "compiler/general_compiler.hpp"
 #include "distributed/distributed.hpp"

 #include <any>
@@ -55,7 +56,8 @@ public:
     RankWorker(const InfinilmModel::Config &model_config,
                const distributed::RankInfo &rank_info,
-               const cache::CacheConfig *cache_config);
+               const cache::CacheConfig *cache_config,
+               bool enable_graph_compiling);

     // Submit a parameter load job and wait until the load completes on the worker thread.
     void load_param(const std::string &name,
@@ -91,6 +93,10 @@ private:
     std::shared_ptr<InfinilmModel> model_;
     std::shared_ptr<cache::Cache> cache_;

+    // Graph Compiling
+    bool enable_graph_compiling_;
+    std::unique_ptr<GraphCompiler> compiler_;
+
     // Command for the pending job (protected by mutex_)
     Command job_cmd_;
......
@@ -43,5 +43,6 @@ public:
     virtual Output forward(const Input &input) const = 0;
     virtual void reset_cache(const cache::CacheConfig *cache_config) = 0;
+    virtual const cache::CacheConfig *get_cache_config() const = 0;
 };

 } // namespace infinilm
@@ -45,7 +45,12 @@ LlamaForCausalLM::Output LlamaForCausalLM::forward(const Input &input) const {
 }

 void LlamaForCausalLM::reset_cache(const cache::CacheConfig *cache_config) {
-    model_->reset_cache(cache_config);
+    cache_config_ = cache_config != nullptr ? cache_config->unique_copy() : nullptr;
+    model_->reset_cache(cache_config_.get());
+}
+
+const cache::CacheConfig *LlamaForCausalLM::get_cache_config() const {
+    return cache_config_.get();
 }

 } // namespace infinilm::models::llama
@@ -42,6 +42,8 @@ public:
     void reset_cache(const cache::CacheConfig *cache_config) override;
+    const cache::CacheConfig *get_cache_config() const override;
+
     // Module information
     const LlamaConfig &config() const { return model_->config(); }
     LlamaModel &model() { return *model_; }
@@ -53,6 +55,8 @@ protected:
     // Language modeling head
     INFINICORE_NN_MODULE(infinicore::nn::Linear, lm_head);
+
+    std::unique_ptr<cache::CacheConfig> cache_config_;
 };

 } // namespace infinilm::models::llama
@@ -35,17 +35,20 @@ inline void bind_infer_engine(py::module &m) {
              const InfinilmModel::Config &cfg,
              const distributed::DistConfig &dist,
              infinicore::Device::Type dev,
-             std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg) {
+             std::shared_ptr<const infinilm::cache::CacheConfig> cache_cfg,
+             bool enable_graph_compiling) {
                 return std::make_shared<InferEngine>(
                     cfg,
                     dist,
                     dev,
-                    cache_cfg ? cache_cfg.get() : nullptr);
+                    cache_cfg ? cache_cfg.get() : nullptr,
+                    enable_graph_compiling);
             }),
             py::arg("config"),
             py::arg("distributed_config") = distributed::DistConfig(),
             py::arg("device_type") = infinicore::context::getDevice().getType(),
-            py::arg("cache_config") = py::none())
+            py::arg("cache_config") = py::none(),
+            py::arg("enable_graph_compiling") = false)
        .def("load_param", &InferEngine::load_param,
             py::arg("name"), py::arg("param"),
             "Load a parameter tensor into all workers (each worker picks its shard)")
......
@@ -3,7 +3,7 @@ from transformers import AutoTokenizer
 from infinilm.modeling_utils import load_model_state_dict_by_file
 from infinilm.distributed import DistConfig
 from infinilm.infer_engine import GenerationConfig, InferEngine
-from infinilm.cache import StaticKVCacheConfig
+from infinilm.cache import StaticKVCacheConfig, PagedKVCacheConfig

 import argparse
 import sys
 import time
@@ -199,7 +199,16 @@ def get_args():
         default=1.0,
         help="sampling temperature",
     )
+    parser.add_argument(
+        "--enable-paged-attn",
+        action="store_true",
+        help="use paged cache",
+    )
+    parser.add_argument(
+        "--enable-graph",
+        action="store_true",
+        help="enable graph compiling",
+    )
     return parser.parse_args()
@@ -223,6 +232,8 @@
         infini_device=infinicore.device("cpu", 0),
         tp=1,
         skip_load=False,
+        cache_config=None,
+        enable_graph=False,
     ) -> None:
         model_path = os.path.expanduser(model_path)
         # ---------------------------------------------------------------------------- #
@@ -232,6 +243,8 @@
             model_path,
             device=infini_device,
             distributed_config=DistConfig(tp),
+            cache_config=cache_config,
+            enable_graph_compiling=enable_graph,
         )

         # ---------------------------------------------------------------------------- #
@@ -336,6 +349,8 @@
     batch_size = args.batch_size
     input_len = args.input_len
     output_len = args.output_len
+    enable_paged_attn = args.enable_paged_attn
+    enable_graph = args.enable_graph

     if isinstance(batch_size, int):
         batch_size = [batch_size]
@@ -350,13 +365,25 @@
     # -------------------------------------------------------- #
     # Test
     # -------------------------------------------------------- #
-    # print("=================== start test ====================", type(batch_size))
+    if enable_paged_attn:
+        paged_kv_block_size = 16
+        max_num_blocks = max(
+            [
+                ((c_["input_len"] + c_["output_len"] + 15) // 16) * c_["batch_size"]
+                for _, c_ in cases_dict.items()
+            ]
+        )
+        cache_config = PagedKVCacheConfig(max_num_blocks, paged_kv_block_size)
+    else:
+        cache_config = None
+
     test = TestModel(
         model_path,
         infini_device=infini_device,
         tp=tp,
         skip_load=skip_load,
+        cache_config=cache_config,
+        enable_graph=enable_graph,
     )

     for idx, case in tqdm(cases_dict.items(), desc="Processing cases"):
@@ -366,13 +393,14 @@
         input_len = case["input_len"]
         output_len = case["output_len"]

-        # reset cache for each case
-        initial_capacity = input_len + output_len
-        test.model.reset_cache(
-            StaticKVCacheConfig(
-                max_batch_size=batch_size, max_cache_len=initial_capacity
-            )
-        )
+        if not enable_paged_attn:
+            # reset cache if static kvcache is used
+            initial_capacity = input_len + output_len
+            test.model.reset_cache(
+                StaticKVCacheConfig(
+                    max_batch_size=batch_size, max_cache_len=initial_capacity
+                )
+            )

         # run test one case
         test.run(
......
@@ -93,6 +93,11 @@ def get_args():
         action="store_true",
         help="use paged cache",
     )
+    parser.add_argument(
+        "--enable-graph",
+        action="store_true",
+        help="enable graph compiling",
+    )

     parser.add_argument(
         "--top-k",
@@ -125,6 +130,7 @@ def test(
     infini_device=infinicore.device("cpu", 0),
     tp=1,
     enable_paged_attn=False,
+    enable_graph=False,
     top_k=1,
     top_p=1.0,
     temperature=1.0,
@@ -137,6 +143,7 @@
         model_path,
         device=infini_device,
         distributed_config=DistConfig(tp),
+        enable_graph_compiling=enable_graph,
     )

     # ---------------------------------------------------------------------------- #
@@ -193,7 +200,7 @@ def test(
         batch_size = 1 if prompts is str else len(prompts)
         max_total_tokens = max_new_tokens + len(input_ids_list[0])
         cache_config = PagedKVCacheConfig(
-            num_blocks=(max_total_tokens // 16 + 1) * batch_size, block_size=16
+            num_blocks=((max_total_tokens + 15) // 16) * batch_size, block_size=16
         )
     else:
         batch_size = 1 if prompts is str else len(prompts)
@@ -265,6 +272,7 @@ if __name__ == "__main__":
     backend = args.backend
     tp = args.tp
     enable_paged_attn = args.enable_paged_attn
+    enable_graph = args.enable_graph

     if backend != "cpp":
         raise ValueError(f"Unsupported backend: {backend}.")
@@ -277,6 +285,7 @@
         infini_device=infini_device,
         tp=tp,
         enable_paged_attn=enable_paged_attn,
+        enable_graph=enable_graph,
         top_k=args.top_k,
         top_p=args.top_p,
         temperature=args.temperature,
......
@@ -28,6 +28,7 @@ class InferEngine(_infinilm.InferEngine):
         device=None,
         distributed_config=DistConfig(1),
         cache_config=None,
+        enable_graph_compiling=False,
     ):
         self.config = AutoConfig.from_pretrained(model_path)
@@ -39,6 +40,7 @@
             distributed_config._underlying,
             device._underlying.type,
             cache_config,
+            enable_graph_compiling,
         )
         self.use_cache = False
......