#ifndef LIGHTGBM_UTILS_ARRAY_AGRS_H_
#define LIGHTGBM_UTILS_ARRAY_AGRS_H_

#include <LightGBM/utils/openmp_wrapper.h>

#include <algorithm>
#include <utility>
#include <vector>

namespace LightGBM {

/*!
* \brief Contains some operations on an array, e.g. ArgMax, ArgMin, top-K selection.
* \tparam VAL_T Element type. The methods compare elements with operator>,
*         operator<, operator== and operator!=; CheckAllZero additionally
*         requires that VAL_T(0) is well-formed.
*/
template<typename VAL_T>
class ArrayArgs {
 public:
  /*!
  * \brief Multi-threaded arg-max: the array is split into one contiguous chunk
  *        per OpenMP thread, each chunk is scanned independently, and the
  *        per-chunk winners are then reduced serially.
  * \param array Values to scan
  * \return Index of the first largest element; 0 if array is empty
  */
  inline static size_t ArgMaxMT(const std::vector<VAL_T>& array) {
    // Guard here as well as in ArgMax: the reduction below reads array[0].
    if (array.empty()) {
      return 0;
    }
    int num_threads = 1;
    // Without OpenMP the pragmas are ignored and everything runs on a single
    // thread, so there is nothing to query.
#ifdef _OPENMP
#pragma omp parallel
#pragma omp master
    {
      num_threads = omp_get_num_threads();
    }
#endif
    num_threads = std::max(1, num_threads);
    // Chunk bounds are computed in size_t so that arrays larger than INT_MAX
    // are neither truncated nor overflow in `step * i`.
    const size_t num_values = array.size();
    const size_t num_chunks = static_cast<size_t>(num_threads);
    const size_t step = (num_values + num_chunks - 1) / num_chunks;  // >= 1, array is non-empty
    // Chunks that receive no work keep index 0, which is always a valid index
    // and can never beat the winner of chunk 0.
    std::vector<size_t> arg_maxs(num_chunks, 0);
    #pragma omp parallel for schedule(static, 1)
    for (int i = 0; i < num_threads; ++i) {
      const size_t start = step * static_cast<size_t>(i);
      if (start >= num_values) { continue; }
      const size_t end = std::min(num_values, start + step);
      size_t arg_max = start;
      for (size_t j = start + 1; j < end; ++j) {
        if (array[j] > array[arg_max]) {
          arg_max = j;
        }
      }
      arg_maxs[i] = arg_max;
    }
    // Chunks are visited in index order with a strict comparison, so ties
    // resolve to the smallest index, exactly like the serial scan.
    size_t ret = arg_maxs[0];
    for (size_t i = 1; i < num_chunks; ++i) {
      if (array[arg_maxs[i]] > array[ret]) {
        ret = arg_maxs[i];
      }
    }
    return ret;
  }

  /*!
  * \brief Index of the first largest element of a vector.
  *        Arrays with more than 1024 elements are dispatched to ArgMaxMT.
  * \param array Values to scan
  * \return Index of the first largest element; 0 if array is empty
  */
  inline static size_t ArgMax(const std::vector<VAL_T>& array) {
    if (array.empty()) {
      return 0;
    }
    if (array.size() > 1024) {
      return ArgMaxMT(array);
    }
    size_t arg_max = 0;
    for (size_t i = 1; i < array.size(); ++i) {
      if (array[i] > array[arg_max]) {
        arg_max = i;
      }
    }
    return arg_max;
  }

  /*!
  * \brief Index of the first smallest element of a vector.
  * \param array Values to scan
  * \return Index of the first smallest element; 0 if array is empty
  */
  inline static size_t ArgMin(const std::vector<VAL_T>& array) {
    if (array.empty()) {
      return 0;
    }
    size_t arg_min = 0;
    for (size_t i = 1; i < array.size(); ++i) {
      if (array[i] < array[arg_min]) {
        arg_min = i;
      }
    }
    return arg_min;
  }

  /*!
  * \brief Index of the first largest element of a raw array.
  * \param array Pointer to the first element; must address at least n elements
  * \param n Number of elements
  * \return Index of the first largest element; 0 if n is 0
  */
  inline static size_t ArgMax(const VAL_T* array, size_t n) {
    if (n == 0) {
      return 0;
    }
    size_t arg_max = 0;
    for (size_t i = 1; i < n; ++i) {
      if (array[i] > array[arg_max]) {
        arg_max = i;
      }
    }
    return arg_max;
  }

  /*!
  * \brief Index of the first smallest element of a raw array.
  * \param array Pointer to the first element; must address at least n elements
  * \param n Number of elements
  * \return Index of the first smallest element; 0 if n is 0
  */
  inline static size_t ArgMin(const VAL_T* array, size_t n) {
    if (n == 0) {
      return 0;
    }
    size_t arg_min = 0;
    for (size_t i = 1; i < n; ++i) {
      if (array[i] < array[arg_min]) {
        arg_min = i;
      }
    }
    return arg_min;
  }

  /*!
  * \brief Three-way partition of [start, end) in descending order, using the
  *        last element of the range as the pivot.
  *        On return, elements in [start, *l] are not smaller than the pivot,
  *        elements in (*l, *r) are equal to the pivot, and elements in
  *        [*r, end) are not larger than the pivot.
  * \param arr Vector that is reordered in place
  * \param start First index of the range (inclusive)
  * \param end Last index of the range (exclusive)
  * \param l Output: last index of the left part (start - 1 if it is empty)
  * \param r Output: first index of the right part
  * \note Nothing is written to l / r when start >= end. The inner scan reads
  *       index end - 2, so a range of a single element is not supported;
  *       ArgMaxAtK never passes one.
  */
  inline static void Partition(std::vector<VAL_T>* arr, int start, int end, int* l, int* r) {
    int i = start - 1;
    int j = end - 1;
    // [start, p] and [q, end - 2] collect pivot-equal elements at both ends
    // while scanning; they are moved next to the pivot afterwards.
    int p = i;
    int q = j;
    if (start >= end) {
      return;
    }
    std::vector<VAL_T>& ref = *arr;
    VAL_T v = ref[end - 1];
    for (;;) {
      // The pivot itself at end - 1 stops this scan.
      while (ref[++i] > v) {}
      while (v > ref[--j]) { if (j == start) { break; } }
      if (i >= j) { break; }
      std::swap(ref[i], ref[j]);
      if (ref[i] == v) { p++; std::swap(ref[p], ref[i]); }
      if (v == ref[j]) { q--; std::swap(ref[j], ref[q]); }
    }
    // Put the pivot at its final position, then pull the pivot-equal
    // elements parked at both ends into the middle.
    std::swap(ref[i], ref[end - 1]);
    j = i - 1;
    i = i + 1;
    for (int k = start; k <= p; k++, j--) { std::swap(ref[k], ref[j]); }
    for (int k = end - 2; k >= q; k--, i++) { std::swap(ref[i], ref[k]); }
    *l = j;
    *r = i;
  }

  /*!
  * \brief Quick-select in descending order: reorders [start, end) in place so
  *        that index k holds the element it would hold if the range were
  *        sorted from largest to smallest.
  * \param arr Vector that is reordered in place
  * \param start First index of the range (inclusive)
  * \param end Last index of the range (exclusive)
  * \param k Target index (not a count), e.g. k = 0 selects the maximum;
  *          expected to lie inside [start, end)
  * \return k, or start when the range holds at most one element
  */
  // Note: k refer to index here. e.g. k=0 means get the max number.
  inline static int ArgMaxAtK(std::vector<VAL_T>* arr, int start, int end, int k) {
    if (start >= end - 1) {
      return start;
    }
    int l = start;
    int r = end - 1;
    Partition(arr, start, end, &l, &r);
    // if find or all elements are the same.
    if ((k > l && k < r) || (l == start - 1 && r == end - 1)) {
      return k;
    } else if (k <= l) {
      return ArgMaxAtK(arr, start, l + 1, k);
    } else {
      return ArgMaxAtK(arr, r, end, k);
    }
  }

  /*!
  * \brief Collects the k largest values of array (in unspecified order).
  * \param array Input values; left untouched
  * \param k Number of values wanted (1-based count, e.g. k = 3 means top-3)
  * \param out Output vector; cleared first. Receives nothing if k <= 0, a full
  *            copy of array if k >= array.size(), otherwise exactly k values.
  *            Must not be the same object as array.
  */
  // Note: k is 1-based here. e.g. k=3 means get the top-3 numbers.
  inline static void MaxK(const std::vector<VAL_T>& array, int k, std::vector<VAL_T>* out) {
    out->clear();
    if (k <= 0) {
      return;
    }
    // Single bulk copy instead of growing the vector one push_back at a time.
    out->assign(array.begin(), array.end());
    if (static_cast<size_t>(k) >= out->size()) {
      return;
    }
    // Move the k largest values into [0, k), then drop the rest.
    ArgMaxAtK(out, 0, static_cast<int>(out->size()), k - 1);
    out->erase(out->begin() + k, out->end());
  }

  /*!
  * \brief Resizes array to n elements and sets every element to t.
  * \param array Vector to overwrite
  * \param t Fill value
  * \param n New size
  */
  inline static void Assign(std::vector<VAL_T>* array, VAL_T t, size_t n) {
    array->assign(n, t);
  }

  /*!
  * \brief Checks whether every element equals VAL_T(0).
  * \param array Values to check
  * \return true if no element differs from VAL_T(0) (also for an empty array)
  */
  inline static bool CheckAllZero(const std::vector<VAL_T>& array) {
    return CheckAll(array, VAL_T(0));
  }

  /*!
  * \brief Checks whether every element equals t.
  * \param array Values to check
  * \param t Value to compare against
  * \return true if no element differs from t (also for an empty array)
  */
  inline static bool CheckAll(const std::vector<VAL_T>& array, VAL_T t) {
    for (const auto& value : array) {
      if (value != t) {
        return false;
      }
    }
    return true;
  }
};

}  // namespace LightGBM

#endif   // LIGHTGBM_UTILS_ARRAY_AGRS_H_