cputhreadpool.h 3.08 KB
Newer Older
zhouxiang's avatar
zhouxiang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
//
// Created by huangyuyang on 7/5/23.
//

#ifndef FASTLLCPUTHREADPOOL_H
#define FASTLLCPUTHREADPOOL_H

#include <condition_variable>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <utility>
#include <vector>

namespace fastllm {
    /**
     * A small mutex-protected FIFO queue.
     *
     * Each member function locks the queue's own mutex, so individual calls may be
     * made from several threads. A *sequence* of calls (e.g. Empty() followed by
     * Pop()) is not atomic; callers needing that must provide their own locking.
     */
    template <typename T>
    class TaskQueue {
    private:
        std::queue <T> q;
        // mutable so the read-only queries (Empty/Size) can be const.
        mutable std::mutex locker;
    public:
        TaskQueue() = default;

        ~TaskQueue() = default;

        /// @return true if the queue currently holds no elements.
        bool Empty() const {
            std::lock_guard<std::mutex> lock(locker);
            return q.empty();
        }

        /// @return number of queued elements (kept as int for API compatibility).
        int Size() const {
            std::lock_guard<std::mutex> lock(locker);
            return static_cast<int>(q.size());
        }

        /// Enqueues a copy of @p t (accepts const and non-const lvalues).
        void Push(const T &t) {
            std::lock_guard<std::mutex> lock(locker);
            q.push(t);
        }

        /// Enqueues @p t by moving it (accepts temporaries).
        void Push(T &&t) {
            std::lock_guard<std::mutex> lock(locker);
            q.push(std::move(t));
        }

        /// Moves the front element into @p t and removes it from the queue.
        /// @return false, leaving @p t untouched, if the queue was empty.
        bool Pop(T &t) {
            std::lock_guard<std::mutex> lock(locker);
            if (q.empty()) {
                return false;
            }
            t = std::move(q.front());
            q.pop();
            return true;
        }
    };

    /**
     * Fixed-size pool of worker threads that run callables passed to Submit().
     *
     * Workers are started by the constructor. Shutdown() (also called by the
     * destructor) requests a stop, wakes every worker and joins them. A task that
     * is already running finishes, but workers take no further tasks once the stop
     * is observed: tasks still queued (or submitted after Shutdown) are not run,
     * and their futures report std::future_error (broken_promise) once the pool is
     * destroyed. Shutdown() must not be called from one of the pool's own tasks
     * (a thread cannot join itself) nor concurrently from several threads.
     */
    class ThreadPool {
    private:
        /// Body of each worker thread: repeatedly takes one task from the pool's
        /// queue and runs it, until shutdown is requested.
        class ThreadWorker
        {
        private:
            ThreadPool *pool;
            int id;  // worker index; not read by the loop below
        public:
            // Declaration order above matches this init list (avoids -Wreorder).
            ThreadWorker(ThreadPool *pool, const int id) : pool(pool), id(id) {}

            void operator()() {
                std::function<void()> func;
                while (true) {
                    {
                        std::unique_lock<std::mutex> lock(pool->locker);
                        // Submit() and Shutdown() change the queue / flag while
                        // holding pool->locker, so checking the predicate under that
                        // same mutex means no notification can slip in between the
                        // check and the wait. The predicate form also absorbs
                        // spurious wakeups.
                        pool->cv.wait(lock, [this] {
                            return pool->shutdown || !pool->queue.Empty();
                        });
                        if (pool->shutdown) {
                            return;
                        }
                        if (!pool->queue.Pop(func)) {
                            continue;
                        }
                    }
                    // Run outside the lock so other workers can dequeue meanwhile.
                    func();
                }
            }
        };

        bool shutdown = false;  // guarded by `locker`
        TaskQueue<std::function<void()>> queue;
        std::vector<std::thread> threads;
        std::mutex locker;
        std::condition_variable cv;
    public:
        /// Starts @p t worker threads.
        ThreadPool(const int t = 4) : threads(std::vector<std::thread>(t)) {
            const int count = static_cast<int>(threads.size());
            try {
                for (int i = 0; i < count; ++i) {
                    threads[i] = std::thread(ThreadWorker(this, i));
                }
            } catch (...) {
                // The destructor will not run for a half-built object; stop and
                // join the threads that did start so none is left joinable
                // (which would call std::terminate).
                Shutdown();
                throw;
            }
        }

        ThreadPool(const ThreadPool &) = delete;
        ThreadPool &operator=(const ThreadPool &) = delete;

        /// Joins all workers; destroying joinable std::threads would terminate.
        ~ThreadPool() {
            Shutdown();
        }

        /// Requests a stop, wakes all workers and joins them. Safe to call again.
        void Shutdown() {
            {
                // Set the flag under the lock the workers wait on, so a worker
                // cannot miss it between checking its predicate and sleeping.
                std::lock_guard<std::mutex> lock(locker);
                shutdown = true;
            }
            cv.notify_all();
            for (std::thread &worker : threads) {
                if (worker.joinable()) {
                    worker.join();
                }
            }
        }

        /**
         * Queues f(args...) for execution on a worker thread.
         *
         * @param f    callable to run; may be move-only.
         * @param args arguments bound to @p f (copied/moved into the task).
         * @return future holding the call's result or the exception it threw.
         */
        template<typename F, typename... Args>
        auto Submit(F &&f, Args &&...args) -> std::future<decltype(f(args...))> {
            using ReturnType = decltype(f(args...));
            // packaged_task (unlike std::function) accepts move-only callables.
            auto task_ptr = std::make_shared<std::packaged_task<ReturnType()>>(
                    std::bind(std::forward<F>(f), std::forward<Args>(args)...));
            std::future<ReturnType> result = task_ptr->get_future();
            {
                // Enqueue under the workers' lock to avoid a lost wakeup.
                std::lock_guard<std::mutex> lock(locker);
                queue.Push(std::function<void()>([task_ptr]() {
                    (*task_ptr)();
                }));
            }
            cv.notify_one();
            return result;
        }
    };
}

#endif //FASTLLCPUTHREADPOOL_H