common.h 22.4 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#ifndef LIGHTGBM_UTILS_COMMON_FUN_H_
#define LIGHTGBM_UTILS_COMMON_FUN_H_

#include <LightGBM/utils/log.h>
9
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
10

11
#include <limits>
Guolin Ke's avatar
Guolin Ke committed
12
#include <string>
13
#include <algorithm>
14
#include <cmath>
15
16
#include <cstdint>
#include <cstdio>
Guolin Ke's avatar
Guolin Ke committed
17
#include <functional>
18
#include <iomanip>
19
#include <iterator>
20
21
#include <memory>
#include <sstream>
Guolin Ke's avatar
Guolin Ke committed
22
#include <type_traits>
23
24
#include <utility>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
25

26
27
28
29
#ifdef _MSC_VER
#include "intrin.h"
#endif

Guolin Ke's avatar
Guolin Ke committed
30
31
32
33
namespace LightGBM {

namespace Common {

34
/*!
 * \brief ASCII-only lower-casing of a single character.
 * \param in Character to convert.
 * \return `in` moved into 'a'..'z' when it lies in 'A'..'Z'; otherwise `in` unchanged.
 */
inline static char tolower(char in) {
  if (in >= 'A' && in <= 'Z') {
    // 'a' - 'A' is the fixed distance between the ASCII upper/lower ranges.
    return static_cast<char>(in + ('a' - 'A'));
  }
  return in;
}

40
/*!
 * \brief Strip leading and trailing whitespace (" \f\n\r\t\v").
 * \param str String to trim (taken by value and modified locally).
 * \return The trimmed string; an all-whitespace input yields "".
 */
inline static std::string Trim(std::string str) {
  if (str.empty()) {
    return str;
  }
  const char* const kWhitespace = " \f\n\r\t\v";
  // For all-whitespace input find_last_not_of returns npos, and npos + 1 == 0 erases everything.
  str.erase(str.find_last_not_of(kWhitespace) + 1);
  str.erase(0, str.find_first_not_of(kWhitespace));
  return str;
}

49
/*!
 * \brief Strip any leading and trailing single (') and double (") quote characters.
 * Quotes in the interior of the string are kept.
 * \param str String to process (taken by value).
 * \return The unquoted string; a string made only of quotes yields "".
 */
inline static std::string RemoveQuotationSymbol(std::string str) {
  if (str.empty()) {
    return str;
  }
  const char* const kQuotes = "'\"";
  str.erase(str.find_last_not_of(kQuotes) + 1);
  str.erase(0, str.find_first_not_of(kQuotes));
  return str;
}
Guolin Ke's avatar
Guolin Ke committed
57

Guolin Ke's avatar
Guolin Ke committed
58
59
60
61
62
63
64
/*!
 * \brief Test whether `str` begins with `prefix`.
 * \return true when the leading prefix.size() characters of `str` equal `prefix`;
 *         an empty prefix always matches.
 */
inline static bool StartsWith(const std::string& str, const std::string prefix) {
  // compare() clamps the length to str.size(), so a shorter str simply mismatches.
  return str.compare(0, prefix.size(), prefix) == 0;
}
Guolin Ke's avatar
Guolin Ke committed
65

Guolin Ke's avatar
Guolin Ke committed
66
/*!
 * \brief Split a NUL-terminated string on a single delimiter character.
 * \param c_str Input string.
 * \param delimiter Separator character.
 * \return Tokens in order of appearance. Consecutive, leading or trailing
 *         delimiters never produce empty tokens.
 */
inline static std::vector<std::string> Split(const char* c_str, char delimiter) {
  std::vector<std::string> ret;
  const std::string str(c_str);
  size_t token_begin = 0;
  for (size_t pos = 0; pos < str.length(); ++pos) {
    if (str[pos] != delimiter) {
      continue;
    }
    if (token_begin < pos) {
      ret.push_back(str.substr(token_begin, pos - token_begin));
    }
    token_begin = pos + 1;
  }
  if (token_begin < str.length()) {
    ret.push_back(str.substr(token_begin));
  }
  return ret;
}

/*!
 * \brief Split a NUL-terminated string into lines.
 * Any run of '\n' / '\r' characters counts as one separator, so "\r\n" line
 * endings and blank lines produce no empty entries.
 * \param c_str Input string.
 * \return Non-empty lines in order.
 */
inline static std::vector<std::string> SplitLines(const char* c_str) {
  std::vector<std::string> ret;
  const std::string str(c_str);
  const size_t len = str.length();
  size_t line_begin = 0;
  size_t pos = 0;
  while (pos < len) {
    if (str[pos] != '\n' && str[pos] != '\r') {
      ++pos;
      continue;
    }
    if (line_begin < pos) {
      ret.push_back(str.substr(line_begin, pos - line_begin));
    }
    // Swallow the whole run of line-ending characters; the explicit bound
    // avoids relying on str[len] being '\0'.
    while (pos < len && (str[pos] == '\n' || str[pos] == '\r')) {
      ++pos;
    }
    line_begin = pos;
  }
  if (line_begin < len) {
    ret.push_back(str.substr(line_begin));
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
112
113
114
115
/*!
 * \brief Split a NUL-terminated string on any of several delimiter characters.
 * \param c_str Input string.
 * \param delimiters NUL-terminated set of separator characters.
 * \return Tokens in order of appearance. Runs of delimiters, and leading or
 *         trailing delimiters, never produce empty tokens.
 */
inline static std::vector<std::string> Split(const char* c_str, const char* delimiters) {
  std::vector<std::string> ret;
  const std::string str(c_str);
  size_t token_begin = 0;
  for (size_t pos = 0; pos < str.length(); ++pos) {
    bool is_delimiter = false;
    for (const char* d = delimiters; *d != '\0'; ++d) {
      if (str[pos] == *d) {
        is_delimiter = true;
        break;
      }
    }
    if (!is_delimiter) {
      continue;
    }
    if (token_begin < pos) {
      ret.push_back(str.substr(token_begin, pos - token_begin));
    }
    token_begin = pos + 1;
  }
  if (token_begin < str.length()) {
    ret.push_back(str.substr(token_begin));
  }
  return ret;
}

141
142
143
144
/*!
 * \brief Parse an optionally signed decimal integer at the start of `p`.
 * Spaces (' ' only) before and after the number are skipped. Parsing stops at
 * the first non-digit. Overflow is not detected.
 * \param p NUL-terminated input.
 * \param out Receives the parsed value (0 when no digits are present).
 * \return Pointer to the first character after the digits and any trailing spaces.
 */
template<typename T>
inline static const char* Atoi(const char* p, T* out) {
  while (*p == ' ') {
    ++p;
  }
  int sign = 1;
  if (*p == '-') {
    sign = -1;
    ++p;
  } else if (*p == '+') {
    ++p;
  }
  T value = 0;
  for (; *p >= '0' && *p <= '9'; ++p) {
    value = value * 10 + (*p - '0');
  }
  *out = static_cast<T>(sign * value);
  while (*p == ' ') {
    ++p;
  }
  return p;
}

165
/*!
 * \brief Compute base^power by repeated squaring/cubing.
 * The base is converted to double first. For integer T, intermediate
 * products such as base*base*base would otherwise overflow T (undefined for
 * signed types), e.g. Pow(10, 12).
 * \param base Base value (any arithmetic type).
 * \param power Integer exponent; negative values return the reciprocal.
 * \return base^power as a double.
 */
template<typename T>
inline static double Pow(T base, int power) {
  const double b = static_cast<double>(base);
  if (power < 0) {
    return 1.0 / Pow(b, -power);
  } else if (power == 0) {
    return 1;
  } else if (power % 2 == 0) {
    return Pow(b * b, power / 2);
  } else if (power % 3 == 0) {
    return Pow(b * b * b, power / 3);
  } else {
    return b * Pow(b, power - 1);
  }
}

180
inline static const char* Atof(const char* p, double* out) {
Guolin Ke's avatar
Guolin Ke committed
181
  int frac;
182
  double sign, value, scale;
Guolin Ke's avatar
Guolin Ke committed
183
  *out = NAN;
Guolin Ke's avatar
Guolin Ke committed
184
185
186
187
188
  // Skip leading white space, if any.
  while (*p == ' ') {
    ++p;
  }
  // Get sign, if any.
189
  sign = 1.0;
Guolin Ke's avatar
Guolin Ke committed
190
  if (*p == '-') {
191
    sign = -1.0;
Guolin Ke's avatar
Guolin Ke committed
192
    ++p;
193
  } else if (*p == '+') {
Guolin Ke's avatar
Guolin Ke committed
194
195
196
    ++p;
  }

Guolin Ke's avatar
Guolin Ke committed
197
198
199
  // is a number
  if ((*p >= '0' && *p <= '9') || *p == '.' || *p == 'e' || *p == 'E') {
    // Get digits before decimal point or exponent, if any.
200
201
    for (value = 0.0; *p >= '0' && *p <= '9'; ++p) {
      value = value * 10.0 + (*p - '0');
Guolin Ke's avatar
Guolin Ke committed
202
    }
Guolin Ke's avatar
Guolin Ke committed
203

Guolin Ke's avatar
Guolin Ke committed
204
205
    // Get digits after decimal point, if any.
    if (*p == '.') {
206
207
      double right = 0.0;
      int nn = 0;
Guolin Ke's avatar
Guolin Ke committed
208
      ++p;
Guolin Ke's avatar
Guolin Ke committed
209
      while (*p >= '0' && *p <= '9') {
210
211
        right = (*p - '0') + right * 10.0;
        ++nn;
Guolin Ke's avatar
Guolin Ke committed
212
213
        ++p;
      }
214
      value += right / Pow(10.0, nn);
Guolin Ke's avatar
Guolin Ke committed
215
216
    }

Guolin Ke's avatar
Guolin Ke committed
217
218
    // Handle exponent, if any.
    frac = 0;
219
    scale = 1.0;
Guolin Ke's avatar
Guolin Ke committed
220
    if ((*p == 'e') || (*p == 'E')) {
Guolin Ke's avatar
Guolin Ke committed
221
      uint32_t expon;
Guolin Ke's avatar
Guolin Ke committed
222
      // Get sign of exponent, if any.
Guolin Ke's avatar
Guolin Ke committed
223
      ++p;
Guolin Ke's avatar
Guolin Ke committed
224
225
226
227
228
229
230
231
232
233
      if (*p == '-') {
        frac = 1;
        ++p;
      } else if (*p == '+') {
        ++p;
      }
      // Get digits of exponent, if any.
      for (expon = 0; *p >= '0' && *p <= '9'; ++p) {
        expon = expon * 10 + (*p - '0');
      }
234
235
236
      if (expon > 308) expon = 308;
      // Calculate scaling factor.
      while (expon >= 50) { scale *= 1E50; expon -= 50; }
Guolin Ke's avatar
Guolin Ke committed
237
      while (expon >= 8) { scale *= 1E8;  expon -= 8; }
238
      while (expon > 0) { scale *= 10.0; expon -= 1; }
Guolin Ke's avatar
Guolin Ke committed
239
    }
Guolin Ke's avatar
Guolin Ke committed
240
241
242
    // Return signed and scaled floating point result.
    *out = sign * (frac ? (value / scale) : (value * scale));
  } else {
243
    size_t cnt = 0;
244
    while (*(p + cnt) != '\0' && *(p + cnt) != ' '
245
246
247
           && *(p + cnt) != '\t' && *(p + cnt) != ','
           && *(p + cnt) != '\n' && *(p + cnt) != '\r'
           && *(p + cnt) != ':') {
248
249
      ++cnt;
    }
250
    if (cnt > 0) {
Guolin Ke's avatar
Guolin Ke committed
251
      std::string tmp_str(p, cnt);
Guolin Ke's avatar
Guolin Ke committed
252
      std::transform(tmp_str.begin(), tmp_str.end(), tmp_str.begin(), Common::tolower);
zhangjin's avatar
zhangjin committed
253
254
      if (tmp_str == std::string("na") || tmp_str == std::string("nan") ||
          tmp_str == std::string("null")) {
Guolin Ke's avatar
Guolin Ke committed
255
        *out = NAN;
256
      } else if (tmp_str == std::string("inf") || tmp_str == std::string("infinity")) {
257
        *out = sign * 1e308;
258
      } else {
259
        Log::Fatal("Unknown token %s in data file", tmp_str.c_str());
Guolin Ke's avatar
Guolin Ke committed
260
261
      }
      p += cnt;
262
    }
Guolin Ke's avatar
Guolin Ke committed
263
  }
Guolin Ke's avatar
Guolin Ke committed
264

Guolin Ke's avatar
Guolin Ke committed
265
266
267
  while (*p == ' ') {
    ++p;
  }
Guolin Ke's avatar
Guolin Ke committed
268

Guolin Ke's avatar
Guolin Ke committed
269
270
271
  return p;
}

272
inline static bool AtoiAndCheck(const char* p, int* out) {
273
274
275
276
277
278
279
  const char* after = Atoi(p, out);
  if (*after != '\0') {
    return false;
  }
  return true;
}

280
inline static bool AtofAndCheck(const char* p, double* out) {
281
282
283
284
285
286
287
  const char* after = Atof(p, out);
  if (*after != '\0') {
    return false;
  }
  return true;
}

288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
/*!
 * \brief Number of decimal digits needed to print `n` (1 for n == 0).
 * On MSVC/GCC, the bit length times log10(2) (approximated as 1233/4096)
 * estimates the digit count. A single table lookup then corrects it.
 * Other compilers count digits by comparing against powers of ten.
 */
inline static unsigned CountDecimalDigit32(uint32_t n) {
#if defined(_MSC_VER) || defined(__GNUC__)
  static const uint32_t powers_of_10[] = {
    0, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000
  };
#ifdef _MSC_VER
  unsigned long msb = 0;
  _BitScanReverse(&msb, n | 1);
  const uint32_t bit_length = static_cast<uint32_t>(msb) + 1;
#elif __GNUC__
  const uint32_t bit_length = 32 - __builtin_clz(n | 1);
#endif
  const uint32_t t = bit_length * 1233 >> 12;
  return t - (n < powers_of_10[t]) + 1;
#else
  unsigned digits = 1;
  for (uint32_t bound = 10; digits < 10 && n >= bound; bound *= 10) {
    ++digits;
  }
  return digits;
#endif
}

/*!
 * \brief Write the decimal representation of `value` into `buffer`, NUL-terminated.
 * Digits are emitted two at a time from a lookup table of "00".."99".
 * \param buffer Destination; must hold CountDecimalDigit32(value) + 1 chars (at most 11).
 */
inline static void Uint32ToStr(uint32_t value, char* buffer) {
  // The pairs "00", "01", ..., "99" stored back to back. Declared static so
  // the table is not rebuilt on every call.
  static const char kDigitsLut[200] = {
    '0', '0', '0', '1', '0', '2', '0', '3', '0', '4', '0', '5', '0', '6', '0', '7', '0', '8', '0', '9',
    '1', '0', '1', '1', '1', '2', '1', '3', '1', '4', '1', '5', '1', '6', '1', '7', '1', '8', '1', '9',
    '2', '0', '2', '1', '2', '2', '2', '3', '2', '4', '2', '5', '2', '6', '2', '7', '2', '8', '2', '9',
    '3', '0', '3', '1', '3', '2', '3', '3', '3', '4', '3', '5', '3', '6', '3', '7', '3', '8', '3', '9',
    '4', '0', '4', '1', '4', '2', '4', '3', '4', '4', '4', '5', '4', '6', '4', '7', '4', '8', '4', '9',
    '5', '0', '5', '1', '5', '2', '5', '3', '5', '4', '5', '5', '5', '6', '5', '7', '5', '8', '5', '9',
    '6', '0', '6', '1', '6', '2', '6', '3', '6', '4', '6', '5', '6', '6', '6', '7', '6', '8', '6', '9',
    '7', '0', '7', '1', '7', '2', '7', '3', '7', '4', '7', '5', '7', '6', '7', '7', '7', '8', '7', '9',
    '8', '0', '8', '1', '8', '2', '8', '3', '8', '4', '8', '5', '8', '6', '8', '7', '8', '8', '8', '9',
    '9', '0', '9', '1', '9', '2', '9', '3', '9', '4', '9', '5', '9', '6', '9', '7', '9', '8', '9', '9'
  };
  unsigned digit = CountDecimalDigit32(value);
  buffer += digit;
  *buffer = '\0';

  // Fill from the least-significant end, two digits per iteration.
  while (value >= 100) {
    const unsigned i = (value % 100) << 1;
    value /= 100;
    *--buffer = kDigitsLut[i + 1];
    *--buffer = kDigitsLut[i];
  }

  // One or two leading digits remain.
  if (value < 10) {
    *--buffer = static_cast<char>(value) + '0';
  } else {
    const unsigned i = value << 1;
    *--buffer = kDigitsLut[i + 1];
    *--buffer = kDigitsLut[i];
  }
}

/*!
 * \brief Write the decimal representation of a signed 32-bit value into `buffer`.
 * Negative values get a leading '-'. The magnitude is computed with unsigned
 * negation, so INT32_MIN is handled correctly.
 * \param buffer Destination; must hold up to 12 chars (sign, 10 digits, NUL).
 */
inline static void Int32ToStr(int32_t value, char* buffer) {
  uint32_t magnitude = static_cast<uint32_t>(value);
  if (value < 0) {
    *buffer = '-';
    ++buffer;
    magnitude = 0u - magnitude;
  }
  Uint32ToStr(magnitude, buffer);
}

366
/*!
 * \brief Format `value` with "%.17g" (enough digits to round-trip a double) into `buffer`.
 * \param value Value to format.
 * \param buffer Destination buffer.
 * \param buffer_len Capacity of `buffer` in bytes. On non-MSVC builds the output
 *        is truncated to fit; previously an unbounded sprintf ignored this limit.
 */
inline static void DoubleToStr(double value, char* buffer, size_t buffer_len) {
  #ifdef _MSC_VER
  sprintf_s(buffer, buffer_len, "%.17g", value);
  #else
  std::snprintf(buffer, buffer_len, "%.17g", value);
  #endif
}

Guolin Ke's avatar
Guolin Ke committed
378
379
380
381
382
383
384
385
386
387
388
389
390
391
/*!
 * \brief Advance past a run of ' ' and '\t' characters.
 * \return Pointer to the first character that is neither.
 */
inline static const char* SkipSpaceAndTab(const char* p) {
  for (; *p == ' ' || *p == '\t'; ++p) {
  }
  return p;
}

/*!
 * \brief Advance past a run of '\n', '\r' and ' ' characters (tabs are not skipped).
 * \return Pointer to the first character not in that set.
 */
inline static const char* SkipReturn(const char* p) {
  for (; *p == '\n' || *p == '\r' || *p == ' '; ++p) {
  }
  return p;
}

Guolin Ke's avatar
Guolin Ke committed
392
393
/*!
 * \brief Element-wise static_cast of a vector from T to T2.
 * \param arr Source vector.
 * \return A new vector of the same length holding static_cast<T2>(arr[i]).
 */
template<typename T, typename T2>
inline static std::vector<T2> ArrayCast(const std::vector<T>& arr) {
  std::vector<T2> ret;
  ret.reserve(arr.size());
  for (const auto& v : arr) {
    ret.push_back(static_cast<T2>(v));
  }
  return ret;
}

401
402
/*!
 * \brief Number-to-text functor used by ArrayToStringFast.
 * The primary template handles signed integral types through Int32ToStr.
 * The value converts to int32_t, so it must fit in 32 bits. The buffer-length
 * argument is unused because the output is at most 12 chars.
 */
template<typename T, bool is_float, bool is_unsign>
struct __TToStringHelperFast {
  void operator()(T value, char* buffer, size_t) const {
    Int32ToStr(value, buffer);
  }
};

/*!
 * \brief Floating-point specialization: formats with "%g" into at most buf_len bytes.
 * The value is passed as double because "%g" expects a double; passing a
 * long double through varargs would be undefined.
 */
template<typename T>
struct __TToStringHelperFast<T, true, false> {
  void operator()(T value, char* buffer, size_t buf_len) const {
    #ifdef _MSC_VER
    sprintf_s(buffer, buf_len, "%g", static_cast<double>(value));
    #else
    std::snprintf(buffer, buf_len, "%g", static_cast<double>(value));
    #endif
  }
};

/*!
 * \brief Unsigned-integral specialization: formats through Uint32ToStr.
 * The value converts to uint32_t, so it must fit in 32 bits.
 */
template<typename T>
struct __TToStringHelperFast<T, false, true> {
  void operator()(T value, char* buffer, size_t) const {
    Uint32ToStr(value, buffer);
  }
};

430
template<typename T>
431
432
inline static std::string ArrayToStringFast(const std::vector<T>& arr, size_t n) {
  if (arr.empty() || n == 0) {
433
    return std::string("");
Guolin Ke's avatar
Guolin Ke committed
434
  }
435
436
437
  __TToStringHelperFast<T, std::is_floating_point<T>::value, std::is_unsigned<T>::value> helper;
  const size_t buf_len = 16;
  std::vector<char> buffer(buf_len);
438
  std::stringstream str_buf;
439
440
441
442
443
  helper(arr[0], buffer.data(), buf_len);
  str_buf << buffer.data();
  for (size_t i = 1; i < std::min(n, arr.size()); ++i) {
    helper(arr[i], buffer.data(), buf_len);
    str_buf << ' ' << buffer.data();
Guolin Ke's avatar
Guolin Ke committed
444
  }
445
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
446
447
}

448
inline static std::string ArrayToString(const std::vector<double>& arr, size_t n) {
Guolin Ke's avatar
Guolin Ke committed
449
450
451
  if (arr.empty() || n == 0) {
    return std::string("");
  }
452
453
  const size_t buf_len = 32;
  std::vector<char> buffer(buf_len);
Guolin Ke's avatar
Guolin Ke committed
454
  std::stringstream str_buf;
455
456
  DoubleToStr(arr[0], buffer.data(), buf_len);
  str_buf << buffer.data();
Guolin Ke's avatar
Guolin Ke committed
457
  for (size_t i = 1; i < std::min(n, arr.size()); ++i) {
458
459
    DoubleToStr(arr[i], buffer.data(), buf_len);
    str_buf << ' ' << buffer.data();
Guolin Ke's avatar
Guolin Ke committed
460
461
462
463
  }
  return str_buf.str();
}

464
465
466
/*!
 * \brief String-to-number functor used by StringToArray.
 * The primary template handles integral T via Atoi; unparsable input yields 0.
 */
template<typename T, bool is_float>
struct __StringToTHelper {
  T operator()(const std::string& str) const {
    T ret = 0;
    Atoi(str.c_str(), &ret);
    return ret;
  }
};

/*!
 * \brief Floating-point specialization: parses with std::stod, then narrows to T.
 * Errors propagate as std::stod's exceptions.
 */
template<typename T>
struct __StringToTHelper<T, true> {
  T operator()(const std::string& str) const {
    const double parsed = std::stod(str);
    return static_cast<T>(parsed);
  }
};

Guolin Ke's avatar
Guolin Ke committed
480
template<typename T>
481
inline static std::vector<T> StringToArray(const std::string& str, char delimiter) {
Guolin Ke's avatar
Guolin Ke committed
482
  std::vector<std::string> strs = Split(str.c_str(), delimiter);
483
484
  std::vector<T> ret;
  ret.reserve(strs.size());
485
  __StringToTHelper<T, std::is_floating_point<T>::value> helper;
486
487
  for (const auto& s : strs) {
    ret.push_back(helper(s));
Guolin Ke's avatar
Guolin Ke committed
488
489
490
491
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
492
template<typename T>
493
494
495
496
497
498
inline static std::vector<T> StringToArray(const std::string& str, int n) {
  if (n == 0) {
    return std::vector<T>();
  }
  std::vector<std::string> strs = Split(str.c_str(), ' ');
  CHECK(strs.size() == static_cast<size_t>(n));
Guolin Ke's avatar
Guolin Ke committed
499
  std::vector<T> ret;
500
501
502
503
  ret.reserve(strs.size());
  __StringToTHelper<T, std::is_floating_point<T>::value> helper;
  for (const auto& s : strs) {
    ret.push_back(helper(s));
Guolin Ke's avatar
Guolin Ke committed
504
505
506
507
  }
  return ret;
}

508
509
510
511
512
513
514
515
516
517
518
519
/*!
 * \brief Pointer-advancing parse functor used by StringToArrayFast.
 * The primary template handles integral T via Atoi.
 * \return Pointer just past the parsed value and any trailing spaces.
 */
template<typename T, bool is_float>
struct __StringToTHelperFast {
  const char* operator()(const char*p, T* out) const {
    const char* next = Atoi(p, out);
    return next;
  }
};

/*!
 * \brief Floating-point specialization: parses with Atof, then narrows to T.
 * \return Pointer just past the parsed token and any trailing spaces.
 */
template<typename T>
struct __StringToTHelperFast<T, true> {
  const char* operator()(const char*p, T* out) const {
    double tmp = 0.0;
    auto ret = Atof(p, &tmp);
    *out = static_cast<T>(tmp);
    return ret;
  }
};

template<typename T>
inline static std::vector<T> StringToArrayFast(const std::string& str, int n) {
  if (n == 0) {
    return std::vector<T>();
  }
  auto p_str = str.c_str();
  __StringToTHelperFast<T, std::is_floating_point<T>::value> helper;
  std::vector<T> ret(n);
  for (int i = 0; i < n; ++i) {
    p_str = helper(p_str, &ret[i]);
  }
  return ret;
}

539
/*!
 * \brief Stream every element of `strs` into one string, separated by `delimiter`.
 * Elements are written with 17 significant digits (digits10 + 2 for double),
 * so floating-point values round-trip.
 * \return "" for an empty vector.
 */
template<typename T>
inline static std::string Join(const std::vector<T>& strs, const char* delimiter) {
  if (strs.empty()) {
    return std::string("");
  }
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  str_buf << strs[0];
  for (size_t i = 1; i < strs.size(); ++i) {
    str_buf << delimiter << strs[i];
  }
  return str_buf.str();
}

554
/*!
 * \brief Join strs[start, end) with `delimiter`, writing elements with 17 significant digits.
 * `end` is clamped to strs.size(). A `start` past the last element is clamped
 * to the last element (kept for backward compatibility).
 * \return "" for an empty or inverted range (end <= start) or an empty `strs`.
 */
template<typename T>
inline static std::string Join(const std::vector<T>& strs, size_t start, size_t end, const char* delimiter) {
  // size_t cannot be negative. The old `end - start <= 0` test only caught
  // end == start, so start > end (or an empty strs) went on to read strs[start].
  if (end <= start || strs.empty()) {
    return std::string("");
  }
  start = std::min(start, strs.size() - 1);
  end = std::min(end, strs.size());
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  str_buf << strs[start];
  for (size_t i = start + 1; i < end; ++i) {
    str_buf << delimiter << strs[i];
  }
  return str_buf.str();
}

571
/*!
 * \brief Smallest power of two that is >= x (1 for x <= 1).
 * \return 0 when that power of two does not fit in int64_t (x > 2^62).
 */
inline static int64_t Pow2RoundUp(int64_t x) {
  // 2^62 is the largest power of two an int64_t can hold. Doubling past it
  // (the old loop did so on its final iteration) is signed overflow.
  const int64_t kMaxPow2 = static_cast<int64_t>(1) << 62;
  int64_t t = 1;
  while (t < x) {
    if (t == kMaxPow2) {
      return 0;
    }
    t <<= 1;
  }
  return t;
}

582
583
584
585
/*!
 * \brief Do in-place softmax transformation on p_rec
 * \param p_rec The input/output vector of the values; an empty vector is left unchanged.
 */
inline static void Softmax(std::vector<double>* p_rec) {
  std::vector<double> &rec = *p_rec;
  if (rec.empty()) {
    return;
  }
  // Subtract the maximum before exponentiating so std::exp cannot overflow.
  double wmax = rec[0];
  for (size_t i = 1; i < rec.size(); ++i) {
    wmax = std::max(rec[i], wmax);
  }
  double wsum = 0.0;
  for (size_t i = 0; i < rec.size(); ++i) {
    rec[i] = std::exp(rec[i] - wmax);
    wsum += rec[i];
  }
  for (size_t i = 0; i < rec.size(); ++i) {
    rec[i] /= wsum;
  }
}

602
/*!
 * \brief Softmax of input[0, len) written to output[0, len).
 * Neither pointer is touched when len <= 0.
 * \param input Source values.
 * \param output Destination; may alias input.
 * \param len Number of elements.
 */
inline static void Softmax(const double* input, double* output, int len) {
  if (len <= 0) {
    return;
  }
  // Subtract the maximum before exponentiating so std::exp cannot overflow.
  double wmax = input[0];
  for (int i = 1; i < len; ++i) {
    wmax = std::max(input[i], wmax);
  }
  double wsum = 0.0;
  for (int i = 0; i < len; ++i) {
    output[i] = std::exp(input[i] - wmax);
    wsum += output[i];
  }
  for (int i = 0; i < len; ++i) {
    output[i] /= wsum;
  }
}

Guolin Ke's avatar
Guolin Ke committed
617
618
619
620
621
/*!
 * \brief Non-owning const view of a vector of unique_ptrs.
 * \return One raw pointer per element (null entries stay null). The pointers
 *         are valid only while `input` keeps ownership.
 */
template<typename T>
std::vector<const T*> ConstPtrInVectorWrapper(const std::vector<std::unique_ptr<T>>& input) {
  std::vector<const T*> ret;
  ret.reserve(input.size());
  for (const auto& ptr : input) {
    ret.push_back(ptr.get());
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
626
/*!
 * \brief Sort keys[start..] and apply the same permutation to values[start..].
 * Uses std::stable_sort, so equal keys keep their relative order. Elements
 * before `start` are untouched.
 * \param keys Keys to sort in place.
 * \param values Payload; must have at least keys.size() elements.
 * \param start First index of the range to sort.
 * \param is_reverse Sort descending instead of ascending.
 */
template<typename T1, typename T2>
inline static void SortForPair(std::vector<T1>& keys, std::vector<T2>& values, size_t start, bool is_reverse = false) {
  std::vector<std::pair<T1, T2>> arr;
  if (start < keys.size()) {
    arr.reserve(keys.size() - start);
  }
  for (size_t i = start; i < keys.size(); ++i) {
    arr.emplace_back(keys[i], values[i]);
  }
  if (!is_reverse) {
    std::stable_sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
      return a.first < b.first;
    });
  } else {
    std::stable_sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
      return a.first > b.first;
    });
  }
  // arr is 0-based: arr[k] belongs at position start + k. The old loop wrote
  // arr[i] to keys[i] for i >= start, which was wrong whenever start > 0.
  for (size_t k = 0; k < arr.size(); ++k) {
    keys[start + k] = arr[k].first;
    values[start + k] = arr[k].second;
  }
}

647
template <typename T>
Guolin Ke's avatar
Guolin Ke committed
648
649
inline static std::vector<T*> Vector2Ptr(std::vector<std::vector<T>>& data) {
  std::vector<T*> ptr(data.size());
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
  for (size_t i = 0; i < data.size(); ++i) {
    ptr[i] = data[i].data();
  }
  return ptr;
}

/*!
 * \brief Sizes of each inner vector, narrowed to int.
 */
template <typename T>
inline static std::vector<int> VectorSize(const std::vector<std::vector<T>>& data) {
  std::vector<int> sizes;
  sizes.reserve(data.size());
  for (const auto& row : data) {
    sizes.push_back(static_cast<int>(row.size()));
  }
  return sizes;
}

Guolin Ke's avatar
Guolin Ke committed
665
/*!
 * \brief Clamp `x` into [-1e300, 1e300] so infinities become large finite values.
 * NaN is returned unchanged because both comparisons are false.
 */
inline static double AvoidInf(double x) {
  if (x >= 1e300) {
    return 1e300;
  } else if (x <= -1e300) {
    return -1e300;
  } else {
    return x;
  }
}

675
/*!
 * \brief Clamp `x` into [-1e38f, 1e38f] so infinities become large finite floats.
 * NaN is returned unchanged because both comparisons are false.
 */
inline static float AvoidInf(float x) {
  if (x >= 1e38) {
    return 1e38f;
  } else if (x <= -1e38) {
    return -1e38f;
  } else {
    return x;
  }
}

/*!
 * \brief Return a null pointer typed as the iterator's value_type.
 * Only the static type is used (see ParallelSort); the result is never dereferenced.
 */
template<typename _Iter> inline
static typename std::iterator_traits<_Iter>::value_type* IteratorValType(_Iter) {
  return nullptr;
}

690
template<typename _RanIt, typename _Pr, typename _VTRanIt> inline
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) {
  size_t len = _Last - _First;
  const size_t kMinInnerLen = 1024;
  int num_threads = 1;
  #pragma omp parallel
  #pragma omp master
  {
    num_threads = omp_get_num_threads();
  }
  if (len <= kMinInnerLen || num_threads <= 1) {
    std::sort(_First, _Last, _Pred);
    return;
  }
  size_t inner_size = (len + num_threads - 1) / num_threads;
  inner_size = std::max(inner_size, kMinInnerLen);
  num_threads = static_cast<int>((len + inner_size - 1) / inner_size);
707
  #pragma omp parallel for schedule(static, 1)
708
709
710
711
712
713
714
715
716
717
718
  for (int i = 0; i < num_threads; ++i) {
    size_t left = inner_size*i;
    size_t right = left + inner_size;
    right = std::min(right, len);
    if (right > left) {
      std::sort(_First + left, _First + right, _Pred);
    }
  }
  // Buffer for merge.
  std::vector<_VTRanIt> temp_buf(len);
  _RanIt buf = temp_buf.begin();
719
  size_t s = inner_size;
720
721
722
  // Recursive merge
  while (s < len) {
    int loop_size = static_cast<int>((len + s * 2 - 1) / (s * 2));
723
    #pragma omp parallel for schedule(static, 1)
724
725
726
727
728
    for (int i = 0; i < loop_size; ++i) {
      size_t left = i * 2 * s;
      size_t mid = left + s;
      size_t right = mid + s;
      right = std::min(len, right);
Guolin Ke's avatar
Guolin Ke committed
729
      if (mid >= right) { continue; }
730
731
732
733
734
735
736
      std::copy(_First + left, _First + mid, buf + left);
      std::merge(buf + left, buf + mid, _First + mid, _First + right, _First + left, _Pred);
    }
    s *= 2;
  }
}

737
/*!
 * \brief ParallelSort overload that deduces the element type from `_First`.
 */
template<typename _RanIt, typename _Pr> inline
static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred) {
  return ParallelSort(_First, _Last, _Pred, IteratorValType(_First));
}

742
// Check that all y[] are in interval [ymin, ymax] (end points included); throws error if not
743
template <typename T>
744
inline static void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, const char *callername) {
745
  auto fatal_msg = [&y, &ymin, &ymax, &callername](int i) {
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
    std::ostringstream os;
    os << "[%s]: does not tolerate element [#%i = " << y[i] << "] outside [" << ymin << ", " << ymax << "]";
    Log::Fatal(os.str().c_str(), callername, i);
  };
  for (int i = 1; i < ny; i += 2) {
    if (y[i - 1] < y[i]) {
      if (y[i - 1] < ymin) {
        fatal_msg(i - 1);
      } else if (y[i] > ymax) {
        fatal_msg(i);
      }
    } else {
      if (y[i - 1] > ymax) {
        fatal_msg(i - 1);
      } else if (y[i] < ymin) {
        fatal_msg(i);
      }
    }
  }
765
  if (ny & 1) {  // odd
766
767
    if (y[ny - 1] < ymin || y[ny - 1] > ymax) {
      fatal_msg(ny - 1);
768
769
770
771
772
773
    }
  }
}

// One-pass scan over array w with nw elements: find min, max and sum of elements;
// this is useful for checking weight requirements.
// Elements are visited in pairs so each pair costs three comparisons instead of four.
// The sum is accumulated in the common type of T1 and T2 (e.g. double for float
// weights with a double sum), which limits rounding error and integer overflow.
// Any of mi / ma / su may be nullptr. For nw <= 0, *su (if requested) is set to 0
// and *mi / *ma are left untouched.
template <typename T1, typename T2>
inline static void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) {
  typedef typename std::common_type<T1, T2>::type AccType;
  if (nw <= 0) {
    if (su != nullptr) {
      *su = static_cast<T2>(0);
    }
    return;
  }
  T1 minw;
  T1 maxw;
  AccType sumw;
  int i;
  if (nw & 1) {  // odd: seed with w[0], then pairs (1,2), (3,4), ...
    minw = w[0];
    maxw = w[0];
    sumw = static_cast<AccType>(w[0]);
    i = 2;
  } else {  // even: seed with the pair (0,1), then pairs (2,3), ...
    if (w[0] < w[1]) {
      minw = w[0];
      maxw = w[1];
    } else {
      minw = w[1];
      maxw = w[0];
    }
    sumw = static_cast<AccType>(w[0]) + static_cast<AccType>(w[1]);
    i = 3;
  }
  for (; i < nw; i += 2) {
    if (w[i - 1] < w[i]) {
      minw = std::min(minw, w[i - 1]);
      maxw = std::max(maxw, w[i]);
    } else {
      minw = std::min(minw, w[i]);
      maxw = std::max(maxw, w[i - 1]);
    }
    sumw += static_cast<AccType>(w[i - 1]) + static_cast<AccType>(w[i]);
  }
  if (mi != nullptr) {
    *mi = minw;
  }
  if (ma != nullptr) {
    *ma = maxw;
  }
  if (su != nullptr) {
    *su = static_cast<T2>(sumw);
  }
}

817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
/*!
 * \brief Allocate a zeroed bitset with room for `n` bits (32 bits per word, rounded up).
 */
inline static std::vector<uint32_t> EmptyBitset(int n) {
  const int num_words = n / 32 + ((n % 32 != 0) ? 1 : 0);
  return std::vector<uint32_t>(num_words);
}

template<typename T>
inline static void InsertBitset(std::vector<uint32_t>& vec, const T val){
    int i1 = val / 32;
    int i2 = val % 32;
    if (static_cast<int>(vec.size()) < i1 + 1) {
      vec.resize(i1 + 1, 0);
    }
    vec[i1] |= (1 << i2);  
}

833
834
template<typename T>
inline static std::vector<uint32_t> ConstructBitset(const T* vals, int n) {
835
836
837
838
839
840
841
842
843
844
845
846
  std::vector<uint32_t> ret;
  for (int i = 0; i < n; ++i) {
    int i1 = vals[i] / 32;
    int i2 = vals[i] % 32;
    if (static_cast<int>(ret.size()) < i1 + 1) {
      ret.resize(i1 + 1, 0);
    }
    ret[i1] |= (1 << i2);
  }
  return ret;
}

847
848
/*!
 * \brief Test whether bit `pos` is set in a bitset of `n` 32-bit words.
 * \return false when `pos` is negative or lies beyond the stored words.
 */
template<typename T>
inline static bool FindInBitset(const uint32_t* bits, int n, T pos) {
  const int i1 = pos / 32;
  const int i2 = pos % 32;
  // C++ division truncates toward zero, so a negative pos gives i1 < 0 or
  // i2 < 0. Either would mean a negative index or a negative shift.
  if (i1 < 0 || i2 < 0 || i1 >= n) {
    return false;
  }
  return ((bits[i1] >> i2) & 1u) != 0;
}

857
858
859
860
861
862
863
864
865
/*!
 * \brief True when `b` does not exceed the next representable double above `a`.
 * Presumably intended for callers that already know a <= b, making it a
 * one-ULP equality test. A NaN argument yields false.
 */
inline static bool CheckDoubleEqualOrdered(double a, double b) {
  const double next_above_a = std::nextafter(a, INFINITY);
  return b <= next_above_a;
}

/*!
 * \brief The smallest double strictly greater than `a` (std::nextafter toward +infinity).
 */
inline static double GetDoubleUpperBound(double a) {
  const double upper = std::nextafter(a, INFINITY);
  return upper;
}

866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
/*!
 * \brief Length of the first line of `str`, excluding its '\n' / '\r' terminator.
 * \return Number of characters before the first '\0', '\n' or '\r'.
 */
inline static size_t GetLine(const char* str) {
  size_t len = 0;
  while (str[len] != '\0' && str[len] != '\n' && str[len] != '\r') {
    ++len;
  }
  return len;
}

/*!
 * \brief Step over at most one line terminator: an optional '\r' followed by an optional '\n'.
 * \return Pointer just past the terminator (unchanged if none is present).
 */
inline static const char* SkipNewLine(const char* str) {
  const char* p = str;
  p += (*p == '\r') ? 1 : 0;
  p += (*p == '\n') ? 1 : 0;
  return p;
}

884
885
886
887
888
/*!
 * \brief Sign of `x`: 1 if positive, -1 if negative, 0 otherwise (including zero and NaN).
 */
template <typename T>
static int Sign(T x) {
  if (x > T(0)) {
    return 1;
  }
  if (x < T(0)) {
    return -1;
  }
  return 0;
}

Guolin Ke's avatar
Guolin Ke committed
889
890
891
892
893
894
895
896
897
/*!
 * \brief Natural log that returns -infinity instead of failing for non-positive input.
 * \return std::log(x) when x > 0; otherwise -INFINITY (this also covers NaN).
 */
template <typename T>
static T SafeLog(T x) {
  if (!(x > 0)) {
    return -INFINITY;
  }
  return std::log(x);
}

Guolin Ke's avatar
Guolin Ke committed
898
899
900
901
}  // namespace Common

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
902
#endif   // LightGBM_UTILS_COMMON_FUN_H_