/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#ifndef LIGHTGBM_INCLUDE_LIGHTGBM_UTILS_ARRAY_ARGS_H_
#define LIGHTGBM_INCLUDE_LIGHTGBM_UTILS_ARRAY_ARGS_H_
Guolin Ke's avatar
Guolin Ke committed
7

8
9
10
#include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/threading.h>

11
12
13
14
#include <algorithm>
#include <utility>
#include <vector>

Guolin Ke's avatar
Guolin Ke committed
15
16
17
namespace LightGBM {

/*!
18
* \brief Contains some operation for an array, e.g. ArgMax, TopK.
Guolin Ke's avatar
Guolin Ke committed
19
20
21
*/
template<typename VAL_T>
class ArrayArgs {
Nikita Titov's avatar
Nikita Titov committed
22
 public:
Guolin Ke's avatar
Guolin Ke committed
23
  inline static size_t ArgMaxMT(const std::vector<VAL_T>& array) {
24
    int num_threads = OMP_NUM_THREADS();
Guolin Ke's avatar
Guolin Ke committed
25
    std::vector<size_t> arg_maxs(num_threads, 0);
Guolin Ke's avatar
Guolin Ke committed
26
27
28
29
30
31
32
33
34
35
36
    int n_blocks = Threading::For<size_t>(
        0, array.size(), 1024,
        [&array, &arg_maxs](int i, size_t start, size_t end) {
          size_t arg_max = start;
          for (size_t j = start + 1; j < end; ++j) {
            if (array[j] > array[arg_max]) {
              arg_max = j;
            }
          }
          arg_maxs[i] = arg_max;
        });
Guolin Ke's avatar
Guolin Ke committed
37
    size_t ret = arg_maxs[0];
Guolin Ke's avatar
Guolin Ke committed
38
    for (int i = 1; i < n_blocks; ++i) {
Guolin Ke's avatar
Guolin Ke committed
39
40
41
42
43
44
      if (array[arg_maxs[i]] > array[ret]) {
        ret = arg_maxs[i];
      }
    }
    return ret;
  }
Guolin Ke's avatar
Guolin Ke committed
45
  inline static size_t ArgMax(const std::vector<VAL_T>& array) {
Guolin Ke's avatar
Guolin Ke committed
46
    if (array.empty()) {
Guolin Ke's avatar
Guolin Ke committed
47
48
      return 0;
    }
Guolin Ke's avatar
Guolin Ke committed
49
    if (array.size() > 1024) {
Guolin Ke's avatar
Guolin Ke committed
50
51
52
53
54
55
56
      return ArgMaxMT(array);
    } else {
      size_t arg_max = 0;
      for (size_t i = 1; i < array.size(); ++i) {
        if (array[i] > array[arg_max]) {
          arg_max = i;
        }
Guolin Ke's avatar
Guolin Ke committed
57
      }
Guolin Ke's avatar
Guolin Ke committed
58
      return arg_max;
Guolin Ke's avatar
Guolin Ke committed
59
60
61
62
    }
  }

  inline static size_t ArgMin(const std::vector<VAL_T>& array) {
Guolin Ke's avatar
Guolin Ke committed
63
    if (array.empty()) {
Guolin Ke's avatar
Guolin Ke committed
64
65
      return 0;
    }
Guolin Ke's avatar
Guolin Ke committed
66
    size_t arg_min = 0;
Guolin Ke's avatar
Guolin Ke committed
67
    for (size_t i = 1; i < array.size(); ++i) {
Guolin Ke's avatar
Guolin Ke committed
68
69
      if (array[i] < array[arg_min]) {
        arg_min = i;
Guolin Ke's avatar
Guolin Ke committed
70
71
      }
    }
Guolin Ke's avatar
Guolin Ke committed
72
    return arg_min;
Guolin Ke's avatar
Guolin Ke committed
73
74
75
76
77
78
  }

  inline static size_t ArgMax(const VAL_T* array, size_t n) {
    if (n <= 0) {
      return 0;
    }
Guolin Ke's avatar
Guolin Ke committed
79
    size_t arg_max = 0;
Guolin Ke's avatar
Guolin Ke committed
80
    for (size_t i = 1; i < n; ++i) {
Guolin Ke's avatar
Guolin Ke committed
81
82
      if (array[i] > array[arg_max]) {
        arg_max = i;
Guolin Ke's avatar
Guolin Ke committed
83
84
      }
    }
Guolin Ke's avatar
Guolin Ke committed
85
    return arg_max;
Guolin Ke's avatar
Guolin Ke committed
86
87
88
89
90
91
  }

  inline static size_t ArgMin(const VAL_T* array, size_t n) {
    if (n <= 0) {
      return 0;
    }
Guolin Ke's avatar
Guolin Ke committed
92
    size_t arg_min = 0;
Guolin Ke's avatar
Guolin Ke committed
93
    for (size_t i = 1; i < n; ++i) {
Guolin Ke's avatar
Guolin Ke committed
94
95
      if (array[i] < array[arg_min]) {
        arg_min = i;
Guolin Ke's avatar
Guolin Ke committed
96
97
      }
    }
Guolin Ke's avatar
Guolin Ke committed
98
    return arg_min;
Guolin Ke's avatar
Guolin Ke committed
99
100
  }

Guolin Ke's avatar
Guolin Ke committed
101
102
103
104
105
  inline static void Partition(std::vector<VAL_T>* arr, int start, int end, int* l, int* r) {
    int i = start - 1;
    int j = end - 1;
    int p = i;
    int q = j;
106
107
108
    if (start >= end - 1) {
      *l = start - 1;
      *r = end;
Guolin Ke's avatar
Guolin Ke committed
109
      return;
Guolin Ke's avatar
Guolin Ke committed
110
    }
Guolin Ke's avatar
Guolin Ke committed
111
112
113
    std::vector<VAL_T>& ref = *arr;
    VAL_T v = ref[end - 1];
    for (;;) {
Guolin Ke's avatar
Guolin Ke committed
114
      while (ref[++i] > v) {}
115
116
117
118
119
120
121
122
      while (v > ref[--j]) {
        if (j == start) {
          break;
        }
      }
      if (i >= j) {
        break;
      }
Guolin Ke's avatar
Guolin Ke committed
123
      std::swap(ref[i], ref[j]);
124
125
126
127
128
129
130
131
      if (ref[i] == v) {
        p++;
        std::swap(ref[p], ref[i]);
      }
      if (v == ref[j]) {
        q--;
        std::swap(ref[j], ref[q]);
      }
Guolin Ke's avatar
Guolin Ke committed
132
133
134
135
    }
    std::swap(ref[i], ref[end - 1]);
    j = i - 1;
    i = i + 1;
136
137
138
139
140
141
    for (int k = start; k <= p; k++, j--) {
      std::swap(ref[k], ref[j]);
    }
    for (int k = end - 2; k >= q; k--, i++) {
      std::swap(ref[i], ref[k]);
    }
Guolin Ke's avatar
Guolin Ke committed
142
143
    *l = j;
    *r = i;
144
  }
Guolin Ke's avatar
Guolin Ke committed
145

Guolin Ke's avatar
Guolin Ke committed
146
  // Note: k refer to index here. e.g. k=0 means get the max number.
Guolin Ke's avatar
Guolin Ke committed
147
148
  inline static int ArgMaxAtK(std::vector<VAL_T>* arr, int start, int end, int k) {
    if (start >= end - 1) {
Guolin Ke's avatar
Guolin Ke committed
149
150
      return start;
    }
Guolin Ke's avatar
Guolin Ke committed
151
152
153
    int l = start;
    int r = end - 1;
    Partition(arr, start, end, &l, &r);
Guolin Ke's avatar
Guolin Ke committed
154
155
    // if find or all elements are the same.
    if ((k > l && k < r) || (l == start - 1 && r == end - 1)) {
Guolin Ke's avatar
Guolin Ke committed
156
157
      return k;
    } else if (k <= l) {
Guolin Ke's avatar
Guolin Ke committed
158
      return ArgMaxAtK(arr, start, l + 1, k);
Guolin Ke's avatar
Guolin Ke committed
159
160
    } else {
      return ArgMaxAtK(arr, r, end, k);
Guolin Ke's avatar
Guolin Ke committed
161
162
163
    }
  }

Guolin Ke's avatar
Guolin Ke committed
164
  // Note: k is 1-based here. e.g. k=3 means get the top-3 numbers.
Guolin Ke's avatar
Guolin Ke committed
165
  inline static void MaxK(const std::vector<VAL_T>& array, int k, std::vector<VAL_T>* out) {
Guolin Ke's avatar
Guolin Ke committed
166
167
168
169
170
171
172
    out->clear();
    if (k <= 0) {
      return;
    }
    for (auto val : array) {
      out->push_back(val);
    }
Guolin Ke's avatar
Guolin Ke committed
173
    if (static_cast<size_t>(k) >= array.size()) {
Guolin Ke's avatar
Guolin Ke committed
174
175
      return;
    }
Guolin Ke's avatar
Guolin Ke committed
176
    ArgMaxAtK(out, 0, static_cast<int>(out->size()), k - 1);
Guolin Ke's avatar
Guolin Ke committed
177
178
179
    out->erase(out->begin() + k, out->end());
  }

Guolin Ke's avatar
Guolin Ke committed
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
  inline static void Assign(std::vector<VAL_T>* array, VAL_T t, size_t n) {
    array->resize(n);
    for (size_t i = 0; i < array->size(); ++i) {
      (*array)[i] = t;
    }
  }

  inline static bool CheckAllZero(const std::vector<VAL_T>& array) {
    for (size_t i = 0; i < array.size(); ++i) {
      if (array[i] != VAL_T(0)) {
        return false;
      }
    }
    return true;
  }

Guolin Ke's avatar
Guolin Ke committed
196
197
198
199
200
201
202
203
  inline static bool CheckAll(const std::vector<VAL_T>& array, VAL_T t) {
    for (size_t i = 0; i < array.size(); ++i) {
      if (array[i] != t) {
        return false;
      }
    }
    return true;
  }
Guolin Ke's avatar
Guolin Ke committed
204
205
206
207
};

}  // namespace LightGBM

208
#endif  // LIGHTGBM_INCLUDE_LIGHTGBM_UTILS_ARRAY_ARGS_H_