threading.h 1002 Bytes
Newer Older
Guolin Ke's avatar
Guolin Ke committed
1
2
3
#ifndef LIGHTGBM_UTILS_THREADING_H_
#define LIGHTGBM_UTILS_THREADING_H_

4
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37

#include <vector>
#include <functional>

namespace LightGBM {

/*!
 * \brief Helpers for splitting index ranges across OpenMP threads.
 */
class Threading {
public:

  /*!
   * \brief Split [start, end) into contiguous chunks and run inner_fun once per chunk.
   *
   * The number of chunks is at most the number of threads OpenMP reports for a
   * parallel region. Every chunk has the same length, ceil((end - start) / threads),
   * except the last one, which is clipped to end. Each chunk's end equals the next
   * chunk's start, so the ranges handed to inner_fun are half-open.
   *
   * \param start First index of the range.
   * \param end One past the last index of the range. If end <= start, inner_fun is never called.
   * \param inner_fun Callback invoked as inner_fun(chunk_index, chunk_start, chunk_end).
   *        It is called from inside an OpenMP parallel loop, so with OpenMP enabled
   *        different chunks may run concurrently. Exceptions are not caught here.
   */
  template<typename INDEX_T>
  static inline void For(INDEX_T start, INDEX_T end, const std::function<void(int, INDEX_T, INDEX_T)>& inner_fun) {
    // Empty or inverted range: no chunk would be dispatched, so skip the parallel
    // regions entirely. Checking before subtracting also keeps an unsigned INDEX_T
    // from wrapping around in (end - start).
    if (!(start < end)) { return; }

    // Ask OpenMP how many threads a parallel region actually gets.
    int num_threads = 1;
    #pragma omp parallel
    #pragma omp master
    {
      num_threads = omp_get_num_threads();
    }

    const INDEX_T total = end - start;  // >= 1 here

    // ceil(total / num_threads), computed without forming total + num_threads - 1,
    // which could overflow when total is close to the maximum of INDEX_T.
    INDEX_T num_inner = total / num_threads;
    if (total % num_threads != 0) { ++num_inner; }

    // Number of non-empty chunks: ceil(total / num_inner). This never exceeds
    // num_threads, so it fits in an int.
    const int num_chunks = static_cast<int>((total - 1) / num_inner) + 1;

    #pragma omp parallel for schedule(static,1)
    for (int i = 0; i < num_chunks; ++i) {
      // num_inner * i <= total - 1 for every i < num_chunks, so inner_start < end.
      const INDEX_T inner_start = start + num_inner * i;
      // Clip the last chunk by comparing against the remaining length instead of
      // computing inner_start + num_inner, which could step past the type's maximum.
      const INDEX_T remaining = end - inner_start;
      const INDEX_T inner_end = (remaining > num_inner) ? inner_start + num_inner : end;
      inner_fun(i, inner_start, inner_end);
    }
  }
};

}   // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
38
#endif   // LIGHTGBM_UTILS_THREADING_H_