Commit 2c925eb4 authored by PanZezhong's avatar PanZezhong Committed by wooway777
Browse files

issue/143 add barrier for compilers

parent 429f54cd
#include "general_compiler.hpp"
namespace infinilm::engine {
GeneralCompiler::GeneralCompiler(const std::shared_ptr<InfinilmModel> &model) : GraphCompiler(model) {
static_batching_compiler_ = std::make_unique<StaticBatchingCompiler>(model_);
paged_compiler_ = std::make_unique<PagedCompiler>(model_);
GeneralCompiler::GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier) : GraphCompiler(model, barrier) {
static_batching_compiler_ = std::make_unique<StaticBatchingCompiler>(model_, barrier);
paged_compiler_ = std::make_unique<PagedCompiler>(model_, barrier);
}
void GeneralCompiler::compile() {
......
......@@ -6,7 +6,7 @@
namespace infinilm::engine {
class GeneralCompiler : public GraphCompiler {
public:
GeneralCompiler(const std::shared_ptr<InfinilmModel> &model);
GeneralCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);
void compile() override;
......
#pragma once
#include "../../models/infinilm_model.hpp"
#include "../rank_barrier.hpp"
namespace infinilm::engine {
......@@ -10,7 +11,7 @@ public:
std::shared_ptr<infinicore::graph::Graph>,
std::shared_ptr<InfinilmModel::Output>>;
explicit GraphCompiler(const std::shared_ptr<InfinilmModel> &model) : model_(model) {}
explicit GraphCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier) : model_(model), barrier_(barrier) {}
virtual ~GraphCompiler() = default;
virtual void compile() = 0;
......@@ -18,6 +19,7 @@ public:
protected:
std::shared_ptr<InfinilmModel> model_;
RankBarrier *barrier_;
};
} // namespace infinilm::engine
#include "paged_compiler.hpp"
namespace infinilm::engine {
PagedCompiler::PagedCompiler(const std::shared_ptr<InfinilmModel> &model)
: GraphCompiler(model) {
PagedCompiler::PagedCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
: GraphCompiler(model, barrier) {
for (size_t b = 1; b < 32; b++) {
decode_batch_sizes_.push_back(b);
}
......@@ -43,9 +43,12 @@ void PagedCompiler::compile() {
infinicore::context::memcpyH2D(input.input_offsets.value()->data(), input_offsets_vec.data(), (b + 1) * sizeof(int64_t), false);
input.block_tables = block_tables_holder_->as_strided({b, block_per_req}, {(ptrdiff_t)block_per_req, 1});
input.slot_mapping = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
barrier_->wait();
infinicore::context::startGraphRecording();
auto output = model_->forward(input);
auto graph = infinicore::context::stopGraphRecording();
barrier_->wait();
auto shared_output = std::shared_ptr<InfinilmModel::Output>(
new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});
......
......@@ -7,7 +7,7 @@
namespace infinilm::engine {
class PagedCompiler : public GraphCompiler {
public:
PagedCompiler(const std::shared_ptr<InfinilmModel> &model);
PagedCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);
void compile() override;
......
......@@ -3,8 +3,8 @@
#include "../../cache/cache.hpp"
namespace infinilm::engine {
StaticBatchingCompiler::StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model)
: GraphCompiler(model) {
StaticBatchingCompiler::StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier)
: GraphCompiler(model, barrier) {
}
void StaticBatchingCompiler::compile() {
......@@ -17,9 +17,12 @@ void StaticBatchingCompiler::compile() {
input.total_sequence_lengths = infinicore::Tensor::empty({b}, infinicore::DataType::I64, infinicore::context::getDevice());
std::vector<int64_t> total_sequence_lengths_vec(b, 1);
infinicore::context::memcpyH2D(input.total_sequence_lengths.value()->data(), total_sequence_lengths_vec.data(), b * sizeof(int64_t), false);
barrier_->wait();
infinicore::context::startGraphRecording();
auto output = model_->forward(input);
auto graph = infinicore::context::stopGraphRecording();
barrier_->wait();
auto shared_output = std::shared_ptr<InfinilmModel::Output>(new InfinilmModel::Output{infinicore::graph::GraphTensor(output.logits)});
......
......@@ -7,7 +7,7 @@
namespace infinilm::engine {
class StaticBatchingCompiler : public GraphCompiler {
public:
StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model);
StaticBatchingCompiler(const std::shared_ptr<InfinilmModel> &model, RankBarrier *barrier);
void compile() override;
......
......@@ -20,12 +20,14 @@ InferEngine::InferEngine(
}
// Create one RankWorker per rank
int world_size = communication_group_.get_world_size();
barrier_ = std::make_unique<RankBarrier>((size_t)world_size);
workers_.reserve(world_size);
for (int r = 0; r < world_size; ++r) {
workers_.emplace_back(std::make_unique<RankWorker>(
model_config_,
communication_group_.get_rank_info(r),
cache_config_ != nullptr ? cache_config_.get() : nullptr,
barrier_.get(),
enable_graph_compiling));
}
}
......@@ -67,9 +69,9 @@ InferEngine::Input::to_model_input(infinicore::Device device) const {
};
return {
input_ids, // @todo: on device in the future
to_device(input_ids), // @todo: on device in the future
to_device(position_ids),
past_sequence_lengths, // @todo: on device in the future
to_device(past_sequence_lengths), // @todo: on device in the future
to_device(total_sequence_lengths),
to_device(input_offsets),
to_device(block_tables),
......@@ -90,6 +92,16 @@ InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
return workers_[0]->get_output();
}
void InferEngine::compile() {
for (auto &worker : workers_) {
worker->compile();
}
// Wait for all workers
for (auto &worker : workers_) {
worker->wait();
}
}
//------------------------------------------------------
// Destructor
//------------------------------------------------------
......@@ -114,6 +126,8 @@ void InferEngine::reset_cache(const cache::CacheConfig *new_config) {
for (auto &worker : workers_) {
worker->wait();
}
this->compile();
}
} // namespace infinilm::engine
......@@ -4,6 +4,7 @@
#include "../models/llama/llama_config.hpp"
#include "distributed/distributed.hpp"
#include "infinicore/tensor.hpp"
#include "rank_barrier.hpp"
#include "rank_worker.hpp"
#include <optional>
......@@ -34,6 +35,8 @@ public:
// Run a single forward pass on all workers and return the outputs from all ranks
Output forward(const Input &input);
void compile();
void reset_cache(const cache::CacheConfig *new_config);
~InferEngine();
......@@ -45,6 +48,7 @@ public:
protected:
std::vector<std::unique_ptr<RankWorker>> workers_;
std::unique_ptr<RankBarrier> barrier_;
distributed::CommunicationGroup communication_group_;
const InfinilmModel::Config &model_config_;
std::unique_ptr<cache::CacheConfig> cache_config_;
......
#include "rank_barrier.hpp"

namespace infinilm::engine {

// Reusable (generation-counting) barrier.
// Init order matches the declaration order in the header
// (thread_count_, arrived_, generation_) so no -Wreorder warning fires.
// NOTE(review): num_ranks == 0 would deadlock the first waiter — callers
// construct this with world_size, presumed >= 1; confirm at the call site.
RankBarrier::RankBarrier(size_t num_ranks) : thread_count_(num_ranks), arrived_(0), generation_(0) {}

// Block until thread_count_ threads have called wait(), then release all
// of them and reset the barrier for the next round.
void RankBarrier::wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    // Snapshot the round counter with the member's own type (size_t, not
    // int): avoids a signed/unsigned comparison in the predicate below and
    // cannot truncate. The snapshot also guards against spurious wakeups
    // and against stragglers from a previous round racing the reset.
    const size_t gen = generation_;
    if (++arrived_ == thread_count_) {
        // Last thread to arrive: open the barrier and rearm it.
        generation_++;
        arrived_ = 0;
        cv_.notify_all();
    } else {
        // Wait for the last arrival of *this* round to bump generation_.
        cv_.wait(lock, [&] { return gen != generation_; });
    }
}

} // namespace infinilm::engine
#pragma once

#include <condition_variable>
#include <cstddef> // size_t — previously pulled in only transitively
#include <mutex>

namespace infinilm::engine {

// Reusable (cyclic) synchronization barrier shared by the rank workers.
// Each call to wait() blocks until `nranks` threads have arrived, then
// releases them all and resets itself for the next round.
class RankBarrier {
public:
    explicit RankBarrier(size_t nranks);

    // Block until all `nranks` threads have called wait() for the current
    // round. The barrier rearms itself, so wait() may be called repeatedly.
    void wait();

private:
    const size_t thread_count_; // number of participating threads
    size_t arrived_;            // threads arrived in the current round
    size_t generation_;         // round counter; guards against spurious wakeups
    std::mutex mutex_;
    std::condition_variable cv_;
};

} // namespace infinilm::engine
......@@ -13,6 +13,7 @@ namespace infinilm::engine {
RankWorker::RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling)
: model_config_(model_config),
rank_info_(rank_info),
......@@ -22,7 +23,8 @@ RankWorker::RankWorker(const InfinilmModel::Config &model_config,
job_done_(false),
should_exit_(false),
init_done_(false),
rng_(std::random_device{}()) {
rng_(std::random_device{}()),
barrier_(barrier) {
if (cache_config != nullptr) {
pending_cache_config_ = cache_config->unique_copy();
}
......@@ -115,6 +117,21 @@ void RankWorker::run(const Input &args) {
cv_.notify_all();
}
//------------------------------------------------------
// compile -- asynchronous
//------------------------------------------------------
// Queue an asynchronous COMPILE job; the worker thread picks it up in
// thread_loop(). Pair with wait() to block until compilation finishes.
// Throws std::runtime_error if the worker is already shutting down.
void RankWorker::compile() {
    std::lock_guard<std::mutex> lock(mutex_);
    if (should_exit_) {
        // Fixed copy-pasted message from run(): this is the compile path.
        throw std::runtime_error("RankWorker is closing; cannot compile");
    }
    // Publish the job under the lock, then wake the worker thread.
    job_cmd_ = Command::COMPILE;
    has_job_ = true;
    job_done_ = false;
    cv_.notify_all();
}
//------------------------------------------------------
// wait -- asynchronous
//------------------------------------------------------
......@@ -183,8 +200,7 @@ void RankWorker::thread_loop() {
throw std::runtime_error("Failed to create model");
}
if (enable_graph_compiling_) {
compiler_ = std::make_unique<GeneralCompiler>(model_);
compiler_->compile();
compiler_ = std::make_unique<GeneralCompiler>(model_, barrier_);
}
init_done_ = true;
......@@ -315,10 +331,25 @@ void RankWorker::thread_loop() {
} else if (local_cmd == Command::RESET_CACHE) {
try {
model_->reset_cache(local_cache_config != nullptr ? local_cache_config.get() : nullptr);
{
std::lock_guard<std::mutex> lk(mutex_);
job_done_ = true;
}
cv_.notify_all();
} catch (const std::exception &e) {
std::lock_guard<std::mutex> lk(mutex_);
should_exit_ = true;
job_done_ = true;
cv_.notify_all();
spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what());
break;
}
} else if (local_cmd == Command::COMPILE) {
try {
if (compiler_ != nullptr) {
compiler_->compile();
}
{
std::lock_guard<std::mutex> lk(mutex_);
job_done_ = true;
......@@ -330,9 +361,10 @@ void RankWorker::thread_loop() {
should_exit_ = true;
job_done_ = true;
cv_.notify_all();
spdlog::error("[{}] exception during reset_cache: {}\n", info(), e.what());
spdlog::error("[{}] exception during compile: {}\n", info(), e.what());
break;
}
} else {
// Shouldn't reach here (no-op)
}
......
......@@ -4,6 +4,7 @@
#include "../models/model_factory.hpp"
#include "compiler/general_compiler.hpp"
#include "distributed/distributed.hpp"
#include "rank_barrier.hpp"
#include <any>
#include <condition_variable>
......@@ -21,6 +22,7 @@ class RankWorker {
LOAD,
RUN,
RESET_CACHE,
COMPILE,
STOP
};
......@@ -57,6 +59,7 @@ public:
RankWorker(const InfinilmModel::Config &model_config,
const distributed::RankInfo &rank_info,
const cache::CacheConfig *cache_config,
RankBarrier *barrier,
bool enable_graph_compiling);
// Submit a parameter load job and wait until the load completes on the worker thread.
......@@ -72,6 +75,9 @@ public:
// Reset the internal cache with a new configuration
void reset_cache(const cache::CacheConfig *new_config);
// Compile the model graph if enabled.
void compile();
// Wait until run job completes. The result can be retrieved with get_output().
void wait();
......@@ -122,6 +128,8 @@ private:
// Random
std::mt19937 rng_;
RankBarrier *barrier_;
};
} // namespace infinilm::engine
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment