"projects/vscode:/vscode.git/clone" did not exist on "80b9c37f8b1f4a83e7f4969a29d8bf8de984fa03"
Commit 9c4020a4 authored by PanZezhong's avatar PanZezhong
Browse files

issue/92 将run改为异步 (make run asynchronous)

parent 3d328d61
...@@ -39,6 +39,10 @@ infinicore::Tensor InferEngine::generate(const infinicore::Tensor &input_ids, ...@@ -39,6 +39,10 @@ infinicore::Tensor InferEngine::generate(const infinicore::Tensor &input_ids,
for (auto &worker : workers_) { for (auto &worker : workers_) {
worker->run(std::vector<std::any>({input_ids, position_ids})); worker->run(std::vector<std::any>({input_ids, position_ids}));
} }
// Wait for all workers
for (auto &worker : workers_) {
worker->wait();
}
return workers_[0]->get_output(); return workers_[0]->get_output();
} }
......
...@@ -75,28 +75,32 @@ void RankWorker::load_param(const std::string &name, ...@@ -75,28 +75,32 @@ void RankWorker::load_param(const std::string &name,
} }
//------------------------------------------------------ //------------------------------------------------------
// run -- synchronous (blocks until worker finishes forward) // run -- asynchronous
//------------------------------------------------------ //------------------------------------------------------
void RankWorker::run(const std::vector<std::any> &args) { void RankWorker::run(const std::vector<std::any> &args) {
{ std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::mutex> lock(mutex_);
if (should_exit_) {
throw std::runtime_error("RankWorker is closing; cannot run");
}
pending_args_ = args; if (should_exit_) {
job_cmd_ = Command::RUN; throw std::runtime_error("RankWorker is closing; cannot run");
has_job_ = true;
job_done_ = false;
} }
pending_args_ = args;
job_cmd_ = Command::RUN;
has_job_ = true;
job_done_ = false;
cv_.notify_all(); cv_.notify_all();
}
// Wait for job completion //------------------------------------------------------
// wait -- synchronous (blocks until the submitted run job finishes)
//------------------------------------------------------
void RankWorker::wait() {
std::unique_lock<std::mutex> lk(mutex_); std::unique_lock<std::mutex> lk(mutex_);
cv_.wait(lk, [&] { return job_done_ || should_exit_; }); cv_.wait(lk, [&] { return job_done_ || should_exit_; });
if (should_exit_) { if (should_exit_) {
throw std::runtime_error("RankWorker stopped while running"); throw std::runtime_error("RankWorker stopped during run");
} }
} }
......
...@@ -28,10 +28,12 @@ public: ...@@ -28,10 +28,12 @@ public:
void load_param(const std::string &name, void load_param(const std::string &name,
const infinicore::Tensor &param); const infinicore::Tensor &param);
// Submit a run (forward) job and wait until completion. // Submit a run (forward) job.
// The result can be retrieved with get_output().
void run(const std::vector<std::any> &args); void run(const std::vector<std::any> &args);
// Wait until run job completes. The result can be retrieved with get_output().
void wait();
// Request worker shutdown and join the thread. // Request worker shutdown and join the thread.
void close(); void close();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment