//
// Created by huangyuyang on 7/5/23.
//

#ifndef FASTLLCPUTHREADPOOL_H
#define FASTLLCPUTHREADPOOL_H

#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <thread>
#include <utility>
#include <vector>

namespace fastllm {
    /// A minimal FIFO queue in which every operation is guarded by an internal mutex.
    ///
    /// Each individual call is atomic. A sequence of calls (e.g. Empty() followed by
    /// Pop()) is not, so callers needing a consistent view must synchronize externally.
    template <typename T>
    class TaskQueue {
    private:
        std::queue<T> q;
        std::mutex locker;
    public:
        TaskQueue() {}

        ~TaskQueue() {}

        /// @return true if the queue currently holds no elements.
        bool Empty() {
            std::unique_lock<std::mutex> lock(locker);
            return q.empty();
        }

        /// @return the number of elements currently queued.
        int Size() {
            std::unique_lock<std::mutex> lock(locker);
            return static_cast<int>(q.size());
        }

        /// Appends a copy of @p t to the back of the queue.
        void Push(const T &t) {
            std::unique_lock<std::mutex> lock(locker);
            q.emplace(t);
        }

        /// Appends @p t to the back of the queue by moving it in.
        void Push(T &&t) {
            std::unique_lock<std::mutex> lock(locker);
            q.emplace(std::move(t));
        }

        /// Removes the front element and moves it into @p t.
        /// @return false (leaving @p t untouched) if the queue was empty.
        bool Pop(T &t) {
            std::unique_lock<std::mutex> lock(locker);
            if (q.empty()) {
                return false;
            }
            t = std::move(q.front());
            q.pop();
            return true;
        }
    };

    /// Fixed-size pool of worker threads that run submitted callables in FIFO order.
    ///
    /// Submit() hands a callable to the pool and returns a future for its result.
    /// Shutdown() stops accepting work, lets the workers finish every task that is
    /// already queued, and joins them. The destructor calls Shutdown(), so a pool may
    /// go out of scope without an explicit Shutdown().
    class ThreadPool {
    private:
        /// Body run by each worker thread: takes one task at a time off the pool's
        /// queue and runs it, exiting once shutdown was requested and the queue is empty.
        class ThreadWorker {
        private:
            int id;            // worker index; not read by the loop itself
            ThreadPool *pool;  // non-owning; the pool joins its workers before it is destroyed
        public:
            ThreadWorker(ThreadPool *pool, const int id) : id(id), pool(pool) {}

            void operator()() {
                for (;;) {
                    std::function<void()> func;
                    {
                        std::unique_lock<std::mutex> lock(pool->locker);
                        // The predicate is checked under pool->locker, the same mutex that
                        // Submit() and Shutdown() hold while changing the queue/flag. A
                        // notification therefore cannot slip in between the check and the
                        // wait. The predicate form also absorbs spurious wakeups.
                        pool->cv.wait(lock, [this] {
                            return pool->shutdown || !pool->queue.Empty();
                        });
                        if (!pool->queue.Pop(func)) {
                            // Woken with nothing queued: shutdown was requested and
                            // all pending work has been drained.
                            return;
                        }
                    }
                    // Run outside the lock so other workers can dequeue concurrently.
                    func();
                }
            }
        };

        bool shutdown = false;                    // guarded by locker
        TaskQueue<std::function<void()>> queue;  // pushed/popped only while holding locker
        std::vector<std::thread> threads;
        std::mutex locker;
        std::condition_variable cv;
    public:
        /// Starts @p t worker threads. Values below 1 are clamped to 1, because a pool
        /// without workers would never complete any submitted task.
        ThreadPool(const int t = 4) {
            const int count = t > 0 ? t : 1;
            threads.reserve(static_cast<size_t>(count));
            try {
                for (int i = 0; i < count; ++i) {
                    threads.emplace_back(ThreadWorker(this, i));
                }
            } catch (...) {
                // The destructor will not run for a partially constructed pool, so stop
                // and join any workers that did start before propagating the error.
                Shutdown();
                throw;
            }
        }

        ThreadPool(const ThreadPool &) = delete;
        ThreadPool &operator=(const ThreadPool &) = delete;

        /// Calls Shutdown(), so no joinable std::thread is ever destroyed
        /// (which would call std::terminate).
        ~ThreadPool() {
            Shutdown();
        }

        /// Stops the pool. Already-queued tasks still run; then every worker is joined.
        /// Calling it again after it has returned is a no-op.
        void Shutdown() {
            {
                std::unique_lock<std::mutex> lock(locker);
                shutdown = true;
            }
            cv.notify_all();
            for (std::thread &worker : threads) {
                if (worker.joinable()) {
                    worker.join();
                }
            }
        }

        /// Queues `f(args...)` for execution on a worker thread.
        ///
        /// @return a future that receives the call's result. An exception thrown by
        ///         the call is stored in the future and rethrown by get().
        /// @throws std::runtime_error if called after Shutdown(); such a task could
        ///         never run and its future would never become ready.
        template <typename F, typename... Args>
        auto Submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))> {
            using ReturnType = decltype(f(args...));
            // packaged_task takes the bind object directly, so move-only callables/args
            // work. It is held through a shared_ptr so the wrapper stays copyable, as
            // std::function requires.
            auto task_ptr = std::make_shared<std::packaged_task<ReturnType()>>(
                    std::bind(std::forward<F>(f), std::forward<Args>(args)...));
            std::future<ReturnType> result = task_ptr->get_future();
            std::function<void()> wrapper_func = [task_ptr]() { (*task_ptr)(); };
            {
                std::unique_lock<std::mutex> lock(locker);
                if (shutdown) {
                    throw std::runtime_error("ThreadPool::Submit called after Shutdown");
                }
                queue.Push(std::move(wrapper_func));
            }
            cv.notify_one();
            return result;
        }
    };
}

#endif //FASTLLCPUTHREADPOOL_H