array_args.h 4.98 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
#ifndef LIGHTGBM_UTILS_ARRAY_AGRS_H_
#define LIGHTGBM_UTILS_ARRAY_AGRS_H_

Guolin Ke's avatar
Guolin Ke committed
8
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
9

10
11
12
13
#include <algorithm>
#include <utility>
#include <vector>

Guolin Ke's avatar
Guolin Ke committed
14
15
16
17
18
19
20
namespace LightGBM {

/*!
* \brief Static helpers for arrays: locating the largest / smallest element (ArgMax, ArgMin),
*        selecting the K largest elements (MaxK) and bulk fill / compare utilities.
* \tparam VAL_T Element type. The helpers compare elements with operator>, operator<,
*         operator== and operator!=.
*/
template<typename VAL_T>
class ArrayArgs {
 public:
  /*!
  * \brief Multi-threaded ArgMax. The array is cut into one contiguous chunk per thread,
  *        every chunk is scanned independently and the per-chunk winners are then
  *        reduced sequentially in chunk order.
  * \param array Values to search
  * \return Index of the first element that no other element is greater than; 0 if array is empty
  */
  inline static size_t ArgMaxMT(const std::vector<VAL_T>& array) {
    const size_t n = array.size();
    if (n == 0) {
      // Nothing to scan; also keeps the final reduction from reading array[0].
      return 0;
    }
    int num_threads = 1;
    // Probe the team size only when OpenMP is enabled; otherwise a single chunk is used.
#ifdef _OPENMP
#pragma omp parallel
#pragma omp master
    {
      num_threads = omp_get_num_threads();
    }
#endif
    const size_t num_chunks = static_cast<size_t>(num_threads);
    // Ceiling division in size_t so large arrays do not overflow int arithmetic.
    const size_t step = (n + num_chunks - 1) / num_chunks;
    // Chunks that turn out to be empty keep index 0, which is always valid because n > 0.
    std::vector<size_t> arg_maxs(num_chunks, 0);
#pragma omp parallel for schedule(static, 1)
    for (int i = 0; i < num_threads; ++i) {
      const size_t start = step * static_cast<size_t>(i);
      if (start >= n) { continue; }
      const size_t end = std::min(n, start + step);
      size_t arg_max = start;
      for (size_t j = start + 1; j < end; ++j) {
        if (array[j] > array[arg_max]) {
          arg_max = j;
        }
      }
      arg_maxs[i] = arg_max;
    }
    // Strict '>' while walking chunks in order keeps the earliest index on ties.
    size_t ret = arg_maxs[0];
    for (size_t i = 1; i < num_chunks; ++i) {
      if (array[arg_maxs[i]] > array[ret]) {
        ret = arg_maxs[i];
      }
    }
    return ret;
  }

  /*!
  * \brief Index of the largest element of a vector.
  *        Arrays with more than 1024 elements are delegated to ArgMaxMT.
  * \param array Values to search
  * \return Index of the first largest element; 0 if array is empty
  */
  inline static size_t ArgMax(const std::vector<VAL_T>& array) {
    if (array.empty()) {
      return 0;
    }
    if (array.size() > 1024) {
      return ArgMaxMT(array);
    }
    size_t arg_max = 0;
    for (size_t i = 1; i < array.size(); ++i) {
      if (array[i] > array[arg_max]) {
        arg_max = i;
      }
    }
    return arg_max;
  }

  /*!
  * \brief Index of the smallest element of a vector.
  * \param array Values to search
  * \return Index of the first smallest element; 0 if array is empty
  */
  inline static size_t ArgMin(const std::vector<VAL_T>& array) {
    if (array.empty()) {
      return 0;
    }
    size_t arg_min = 0;
    for (size_t i = 1; i < array.size(); ++i) {
      if (array[i] < array[arg_min]) {
        arg_min = i;
      }
    }
    return arg_min;
  }

  /*!
  * \brief Index of the largest element of a raw array.
  * \param array Pointer to the first element
  * \param n Number of elements
  * \return Index of the first largest element; 0 if n is 0 or array is null
  */
  inline static size_t ArgMax(const VAL_T* array, size_t n) {
    if (array == nullptr || n == 0) {
      return 0;
    }
    size_t arg_max = 0;
    for (size_t i = 1; i < n; ++i) {
      if (array[i] > array[arg_max]) {
        arg_max = i;
      }
    }
    return arg_max;
  }

  /*!
  * \brief Index of the smallest element of a raw array.
  * \param array Pointer to the first element
  * \param n Number of elements
  * \return Index of the first smallest element; 0 if n is 0 or array is null
  */
  inline static size_t ArgMin(const VAL_T* array, size_t n) {
    if (array == nullptr || n == 0) {
      return 0;
    }
    size_t arg_min = 0;
    for (size_t i = 1; i < n; ++i) {
      if (array[i] < array[arg_min]) {
        arg_min = i;
      }
    }
    return arg_min;
  }

  /*!
  * \brief Three-way partition of (*arr)[start, end) in descending order, using
  *        (*arr)[end - 1] as the pivot.
  *        On return, indices [start, *l] hold elements greater than the pivot, indices
  *        strictly between *l and *r hold elements equal to the pivot, and no element
  *        at indices [*r, end) is greater than the pivot.
  * \param arr Array to partition in place
  * \param start First index of the range (inclusive)
  * \param end Last index of the range (exclusive)
  * \param l Output: last index of the "greater than pivot" part (start - 1 if that part is empty)
  * \param r Output: first index of the part after the pivot-equal run (end if that part is empty)
  * \note *l and *r are left untouched when the range is empty.
  */
  inline static void Partition(std::vector<VAL_T>* arr, int start, int end, int* l, int* r) {
    if (start >= end) {
      return;
    }
    if (end - start == 1) {
      // A lone element is its own pivot. Answer directly: the backward scan below
      // would otherwise start by reading ref[start - 1].
      *l = start - 1;
      *r = end;
      return;
    }
    int i = start - 1;  // forward scan cursor
    int j = end - 1;    // backward scan cursor
    int p = i;          // last index of the pivot-equal run parked at the left edge
    int q = j;          // first index of the pivot-equal run parked at the right edge
    std::vector<VAL_T>& ref = *arr;
    VAL_T v = ref[end - 1];
    for (;;) {
      // The pivot at end - 1 acts as a sentinel for the forward scan.
      while (ref[++i] > v) {}
      while (v > ref[--j]) { if (j == start) { break; } }
      if (i >= j) { break; }
      std::swap(ref[i], ref[j]);
      // Park elements equal to the pivot at the two edges; they are moved to the middle below.
      if (ref[i] == v) { p++; std::swap(ref[p], ref[i]); }
      if (v == ref[j]) { q--; std::swap(ref[j], ref[q]); }
    }
    // Put the pivot into its final slot, then pull the parked equal elements next to it.
    std::swap(ref[i], ref[end - 1]);
    j = i - 1;
    i = i + 1;
    for (int k = start; k <= p; k++, j--) { std::swap(ref[k], ref[j]); }
    for (int k = end - 2; k >= q; k--, i++) { std::swap(ref[i], ref[k]); }
    *l = j;
    *r = i;
  }

  /*!
  * \brief Quickselect in descending order: rearranges (*arr)[start, end) so that the
  *        element that would sit at index k after a full descending sort is at index k,
  *        with nothing smaller before it and nothing larger after it inside the range.
  * \note k is an index here, e.g. k=0 means get the max number.
  * \param arr Array to rearrange in place
  * \param start First index of the range (inclusive)
  * \param end Last index of the range (exclusive)
  * \param k Target index, expected to lie in [start, end)
  * \return k once the position is settled by a partition step, or the start of the
  *         remaining range when it shrinks to at most one element
  */
  inline static int ArgMaxAtK(std::vector<VAL_T>* arr, int start, int end, int k) {
    // Iterative narrowing: same steps as the recursive formulation, but with constant
    // stack depth even when every partition step removes only one element.
    while (start < end - 1) {
      int l = start;
      int r = end - 1;
      Partition(arr, start, end, &l, &r);
      // if find or all elements are the same.
      if ((k > l && k < r) || (l == start - 1 && r == end - 1)) {
        return k;
      }
      if (k <= l) {
        end = l + 1;
      } else {
        start = r;
      }
    }
    return start;
  }

  /*!
  * \brief Copies the k largest elements of array into out (in no particular order).
  * \note k is 1-based here, e.g. k=3 means get the top-3 numbers.
  * \param array Source values
  * \param k Number of elements wanted; k <= 0 yields an empty result and
  *          k >= array.size() yields a copy of the whole array
  * \param out Receives the result; may point to array itself, in which case the
  *            selection is performed in place
  */
  inline static void MaxK(const std::vector<VAL_T>& array, int k, std::vector<VAL_T>* out) {
    if (k <= 0) {
      out->clear();
      return;
    }
    // Skip the copy when out aliases array; clearing first would discard the input.
    if (out != &array) {
      *out = array;
    }
    if (static_cast<size_t>(k) >= out->size()) {
      return;
    }
    ArgMaxAtK(out, 0, static_cast<int>(out->size()), k - 1);
    out->erase(out->begin() + k, out->end());
  }

  /*!
  * \brief Resizes array to n elements and sets every element to t.
  * \param array Vector to overwrite
  * \param t Value to store in every slot
  * \param n New size of the vector
  */
  inline static void Assign(std::vector<VAL_T>* array, VAL_T t, size_t n) {
    array->assign(n, t);
  }

  /*!
  * \brief Checks whether every element equals VAL_T(0).
  * \param array Values to inspect
  * \return true if array is empty or all elements compare equal to VAL_T(0)
  */
  inline static bool CheckAllZero(const std::vector<VAL_T>& array) {
    for (size_t i = 0; i < array.size(); ++i) {
      if (array[i] != VAL_T(0)) {
        return false;
      }
    }
    return true;
  }

  /*!
  * \brief Checks whether every element equals t.
  * \param array Values to inspect
  * \param t Value to compare against
  * \return true if array is empty or all elements compare equal to t
  */
  inline static bool CheckAll(const std::vector<VAL_T>& array, VAL_T t) {
    for (size_t i = 0; i < array.size(); ++i) {
      if (array[i] != t) {
        return false;
      }
    }
    return true;
  }
};

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
197
#endif   // LIGHTGBM_UTILS_ARRAY_AGRS_H_
Guolin Ke's avatar
Guolin Ke committed
198