"docs/vscode:/vscode.git/clone" did not exist on "7fbab730bd8e91b85e3b2ee2defc9a6de2a09a7c"
backend.cpp 5.12 KB
Newer Older
chenxl's avatar
chenxl committed
1
2
3
4
5
/**
 * @Description  : Work-stealing thread pool (Backend) used to run parallel jobs.
 * @Author       : chenht2022
 * @Date         : 2024-07-22 02:03:05
 * @Version      : 1.0.0
 * @LastEditors  : chenht2022
 * @LastEditTime : 2024-07-25 10:33:34
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
chenxl's avatar
chenxl committed
10

chenxl's avatar
chenxl committed
11
12
#include "backend.h"

#include <algorithm>
#include <utility>

liam's avatar
liam committed
13
14
15
16
17
18
19
#ifdef USE_NUMA
#include <numa.h>
#include <numaif.h>

// NUMA node this thread has been bound to by process_tasks; stays -1 until the
// thread's first job performs the binding.
thread_local int Backend::numa_node{-1};
#endif

chenxl's avatar
chenxl committed
20
21
22
23
24
25
// Per-thread index within the pool. do_work_stealing_job sets it to 0 on the
// calling thread and worker_thread sets it to the worker's id; -1 means no id
// has been assigned to this thread yet.
thread_local int Backend::thread_local_id{-1};

// Creates a pool sized for `max_thread_num` participating threads.
//
// Slot 0 belongs to whichever thread calls do_work_stealing_job, so only slots
// 1 .. max_thread_num - 1 get a dedicated std::thread running worker_thread;
// workers_[0] stays a default-constructed (non-joinable) std::thread.
// Every slot starts out WAITING with its task cursor at 0.
Backend::Backend(int max_thread_num) {
    max_thread_num_ = max_thread_num;

    thread_state_.resize(max_thread_num_);
    for (auto& slot : thread_state_) {
        slot.curr = std::make_unique<std::atomic<int>>(0);
        slot.status =
            std::make_unique<std::atomic<ThreadStatus>>(ThreadStatus::WAITING);
    }

    workers_.resize(max_thread_num_);
    for (int id = 1; id < max_thread_num_; ++id) {
        workers_[id] = std::thread(&Backend::worker_thread, this, id);
    }
}

// Shuts the pool down: flags every slot EXIT (release store, paired with the
// acquire load in worker_thread), then joins every worker that was started.
// workers_[0] is never started, so the joinable() check skips it.
Backend::~Backend() {
    for (auto& slot : thread_state_) {
        slot.status->store(ThreadStatus::EXIT, std::memory_order_release);
    }
    for (auto& worker : workers_) {
        if (!worker.joinable()) {
            continue;
        }
        worker.join();
    }
}

chenxl's avatar
chenxl committed
48
// Returns the pool size given to the constructor (the calling thread counts as
// thread 0), not the number of threads used by the most recent job.
int Backend::get_thread_num() {
    const int pool_size = max_thread_num_;
    return pool_size;
}
chenxl's avatar
chenxl committed
49

chenxl's avatar
chenxl committed
50
51
52
53
54
55
56
void Backend::do_work_stealing_job(int task_num,
                                   std::function<void(int)> init_func,
                                   std::function<void(int)> compute_func,
                                   std::function<void(int)> finalize_func) {
    init_func_ = init_func;
    compute_func_ = compute_func;
    finalize_func_ = finalize_func;
wkgcass's avatar
wkgcass committed
57
58
59
60
#ifdef USE_NUMA
    // numa node location will be calculated based on the number of threads
    thread_num_ = max_thread_num_;
#else
chenxl's avatar
chenxl committed
61
    thread_num_ = std::min(max_thread_num_, task_num);
wkgcass's avatar
wkgcass committed
62
#endif
chenxl's avatar
chenxl committed
63
64
65
    int base = task_num / thread_num_;
    int remain = task_num % thread_num_;
    thread_state_[0].end = base + (0 < remain);
chenxl's avatar
chenxl committed
66
67
68
69

    // 为主线程设置 thread_local_id
    thread_local_id = 0;

chenxl's avatar
chenxl committed
70
    for (int i = 1; i < thread_num_; i++) {
chenxl's avatar
chenxl committed
71
72
        thread_state_[i].curr->store(thread_state_[i - 1].end,
                                     std::memory_order_relaxed);
chenxl's avatar
chenxl committed
73
        thread_state_[i].end = thread_state_[i - 1].end + base + (i < remain);
chenxl's avatar
chenxl committed
74
75
        thread_state_[i].status->store(ThreadStatus::WORKING,
                                       std::memory_order_release);
chenxl's avatar
chenxl committed
76
77
    }
    thread_state_[0].curr->store(0, std::memory_order_relaxed);
chenxl's avatar
chenxl committed
78
79
    thread_state_[0].status->store(ThreadStatus::WORKING,
                                   std::memory_order_release);
chenxl's avatar
chenxl committed
80
81
    process_tasks(0);
    for (int i = 1; i < thread_num_; i++) {
chenxl's avatar
chenxl committed
82
83
        while (thread_state_[i].status->load(std::memory_order_acquire) ==
               ThreadStatus::WORKING) {
chenxl's avatar
chenxl committed
84
85
86
87
88
        }
    }
}

void Backend::process_tasks(int thread_id) {
liam's avatar
liam committed
89
90
91
92
93
94
95
96
97
98
    
    #ifdef USE_NUMA
    if(numa_node == -1){
        numa_node = thread_id * numa_num_configured_nodes() / thread_num_;
        struct bitmask* mask = numa_bitmask_alloc(numa_num_configured_nodes());
        numa_bitmask_setbit(mask, numa_node);
        numa_bind(mask);
    }
    #endif

chenxl's avatar
chenxl committed
99
100
101
    if (init_func_ != nullptr) {
        init_func_(thread_id);
    }
chenxl's avatar
chenxl committed
102
    while (true) {
chenxl's avatar
chenxl committed
103
104
        int task_id = thread_state_[thread_id].curr->fetch_add(
            1, std::memory_order_acq_rel);
chenxl's avatar
chenxl committed
105
106
107
        if (task_id >= thread_state_[thread_id].end) {
            break;
        }
chenxl's avatar
chenxl committed
108
        compute_func_(task_id);
chenxl's avatar
chenxl committed
109
110
111
    }
    for (int t_offset = 1; t_offset < thread_num_; t_offset++) {
        int t_i = (thread_id + t_offset) % thread_num_;
chenxl's avatar
chenxl committed
112
113
        if (thread_state_[t_i].status->load(std::memory_order_acquire) !=
            ThreadStatus::WORKING) {
chenxl's avatar
chenxl committed
114
115
116
            continue;
        }
        while (true) {
chenxl's avatar
chenxl committed
117
118
            int task_id = thread_state_[t_i].curr->fetch_add(
                1, std::memory_order_acq_rel);
chenxl's avatar
chenxl committed
119
120
121
            if (task_id >= thread_state_[t_i].end) {
                break;
            }
chenxl's avatar
chenxl committed
122
            compute_func_(task_id);
chenxl's avatar
chenxl committed
123
124
        }
    }
chenxl's avatar
chenxl committed
125
126
127
128
129
    if (finalize_func_ != nullptr) {
        finalize_func_(thread_id);
    }
    thread_state_[thread_id].status->store(ThreadStatus::WAITING,
                                           std::memory_order_release);
chenxl's avatar
chenxl committed
130
131
132
133
}

void Backend::worker_thread(int thread_id) {
    auto start = std::chrono::steady_clock::now();
chenxl's avatar
chenxl committed
134
    thread_local_id = thread_id; // 设置线程本地变量
chenxl's avatar
chenxl committed
135
    while (true) {
chenxl's avatar
chenxl committed
136
137
        ThreadStatus status =
            thread_state_[thread_id].status->load(std::memory_order_acquire);
chenxl's avatar
chenxl committed
138
139
140
141
142
        if (status == ThreadStatus::WORKING) {
            process_tasks(thread_id);
            start = std::chrono::steady_clock::now();
        } else if (status == ThreadStatus::WAITING) {
            auto now = std::chrono::steady_clock::now();
chenxl's avatar
chenxl committed
143
144
145
146
            auto duration =
                std::chrono::duration_cast<std::chrono::milliseconds>(now -
                                                                      start)
                    .count();
chenxl's avatar
chenxl committed
147
148
149
150
151
152
153
            if (duration > 50) {
                std::this_thread::sleep_for(std::chrono::milliseconds(1));
            }
        } else if (status == ThreadStatus::EXIT) {
            return;
        }
    }
154
}