Unverified Commit 8c224092 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #108 from InfiniTensor/issue/92-b

issue/92 将run改为异步
parents 3d328d61 9c4020a4
......@@ -39,6 +39,10 @@ infinicore::Tensor InferEngine::generate(const infinicore::Tensor &input_ids,
for (auto &worker : workers_) {
worker->run(std::vector<std::any>({input_ids, position_ids}));
}
// Wait for all workers
for (auto &worker : workers_) {
worker->wait();
}
return workers_[0]->get_output();
}
......
......@@ -75,28 +75,32 @@ void RankWorker::load_param(const std::string &name,
}
//------------------------------------------------------
// run -- asynchronous (submits the job and returns; use wait() to block)
//------------------------------------------------------
// Publishes a RUN job under mutex_ and signals cv_. Does not wait for the
// job to finish; completion is observed through wait().
//
// @param args  Arguments copied into pending_args_ for the job.
// @throws std::runtime_error if should_exit_ is already set.
void RankWorker::run(const std::vector<std::any> &args) {
    std::lock_guard<std::mutex> lock(mutex_);

    // Refuse new work once shutdown has been requested.
    if (should_exit_) {
        throw std::runtime_error("RankWorker is closing; cannot run");
    }

    // Publish the job; job_done_ is cleared so a later wait() blocks until
    // it is set again (presumably by the worker thread -- not visible here).
    pending_args_ = args;
    job_cmd_ = Command::RUN;
    has_job_ = true;
    job_done_ = false;

    // Notify while still holding the lock, as the original does.
    cv_.notify_all();
}
//------------------------------------------------------
// wait -- blocks until the submitted job is done
//------------------------------------------------------
// Blocks the caller on cv_ until job_done_ or should_exit_ becomes true.
//
// @throws std::runtime_error if should_exit_ is set when the wait ends
//         (checked even when job_done_ is also true).
void RankWorker::wait() {
    std::unique_lock<std::mutex> lk(mutex_);
    // Predicate form guards against spurious wake-ups.
    cv_.wait(lk, [&] { return job_done_ || should_exit_; });
    if (should_exit_) {
        throw std::runtime_error("RankWorker stopped during run");
    }
}
......
......@@ -28,10 +28,12 @@ public:
void load_param(const std::string &name,
const infinicore::Tensor &param);
// Submit a run (forward) job and return without waiting for it to finish.
void run(const std::vector<std::any> &args);
// Wait until run job completes. The result can be retrieved with get_output().
void wait();
// Request worker shutdown and join the thread.
void close();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment