common.h 23.7 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#ifndef LIGHTGBM_UTILS_COMMON_FUN_H_
#define LIGHTGBM_UTILS_COMMON_FUN_H_

#include <LightGBM/utils/log.h>
9
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
10

11
#include <limits>
Guolin Ke's avatar
Guolin Ke committed
12
#include <string>
13
#include <algorithm>
14
#include <cmath>
15
16
#include <cstdint>
#include <cstdio>
Guolin Ke's avatar
Guolin Ke committed
17
#include <functional>
18
#include <iomanip>
19
#include <iterator>
20
21
#include <memory>
#include <sstream>
Guolin Ke's avatar
Guolin Ke committed
22
#include <type_traits>
23
24
#include <utility>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
25

26
27
28
29
#ifdef _MSC_VER
#include "intrin.h"
#endif

Guolin Ke's avatar
Guolin Ke committed
30
31
32
33
namespace LightGBM {

namespace Common {

34
/*!
 * \brief Map an upper-case ASCII letter ('A'..'Z') to its lower-case form.
 * \param in Character to convert
 * \return Lower-case letter when in is upper-case, otherwise in unchanged
 */
inline static char tolower(char in) {
  const bool is_upper = (in >= 'A' && in <= 'Z');
  // 'a' - 'A' is the fixed distance between the two ASCII letter ranges.
  return is_upper ? static_cast<char>(in + ('a' - 'A')) : in;
}

40
/*!
 * \brief Strip leading and trailing whitespace (space, \f, \n, \r, \t, \v).
 * \param str String to trim (taken by value and modified locally)
 * \return The trimmed copy
 */
inline static std::string Trim(std::string str) {
  const char* const kWhitespace = " \f\n\r\t\v";
  if (!str.empty()) {
    // npos + 1 wraps to 0, so an all-whitespace string is erased entirely.
    const size_t last_kept = str.find_last_not_of(kWhitespace);
    str.erase(last_kept + 1);
    const size_t first_kept = str.find_first_not_of(kWhitespace);
    str.erase(0, first_kept);
  }
  return str;
}

49
/*!
 * \brief Strip leading and trailing single/double quote characters.
 * \param str String to process (taken by value and modified locally)
 * \return Copy of str without surrounding quote characters
 */
inline static std::string RemoveQuotationSymbol(std::string str) {
  const char* const kQuotes = "'\"";
  if (!str.empty()) {
    // npos + 1 wraps to 0, so a string made only of quotes becomes empty.
    const size_t last_kept = str.find_last_not_of(kQuotes);
    str.erase(last_kept + 1);
    const size_t first_kept = str.find_first_not_of(kQuotes);
    str.erase(0, first_kept);
  }
  return str;
}
Guolin Ke's avatar
Guolin Ke committed
57

Guolin Ke's avatar
Guolin Ke committed
58
59
60
61
62
63
64
/*!
 * \brief Check whether str begins with prefix.
 * \param str String to inspect
 * \param prefix Expected leading characters; an empty prefix always matches
 * \return true if the first prefix.size() characters of str equal prefix
 */
inline static bool StartsWith(const std::string& str, const std::string& prefix) {
  // compare() works in place, so no temporary substring is allocated; it
  // yields a non-zero result when str is shorter than prefix.
  return str.compare(0, prefix.size(), prefix) == 0;
}
Guolin Ke's avatar
Guolin Ke committed
65

Guolin Ke's avatar
Guolin Ke committed
66
/*!
 * \brief Split a C string on a single delimiter character, dropping empty tokens.
 * \param c_str Null-terminated input
 * \param delimiter Separator character
 * \return Non-empty tokens in order of appearance
 */
inline static std::vector<std::string> Split(const char* c_str, char delimiter) {
  const std::string text(c_str);
  std::vector<std::string> tokens;
  size_t token_begin = 0;
  for (size_t cur = 0; cur < text.length(); ++cur) {
    if (text[cur] != delimiter) {
      continue;
    }
    if (token_begin < cur) {
      tokens.push_back(text.substr(token_begin, cur - token_begin));
    }
    token_begin = cur + 1;
  }
  // Trailing token after the last delimiter, if any.
  if (token_begin < text.length()) {
    tokens.push_back(text.substr(token_begin));
  }
  return tokens;
}

/*!
 * \brief Split a C string into lines on any run of '\n' / '\r', dropping empty lines.
 * \param c_str Null-terminated input
 * \return Non-empty lines in order of appearance
 */
inline static std::vector<std::string> SplitLines(const char* c_str) {
  const std::string text(c_str);
  const size_t len = text.length();
  const auto is_eol = [](char ch) { return ch == '\n' || ch == '\r'; };
  std::vector<std::string> lines;
  size_t line_begin = 0;
  size_t cur = 0;
  while (cur < len) {
    if (!is_eol(text[cur])) {
      ++cur;
      continue;
    }
    if (line_begin < cur) {
      lines.push_back(text.substr(line_begin, cur - line_begin));
    }
    // Consume the whole run of line-ending characters.
    while (cur < len && is_eol(text[cur])) {
      ++cur;
    }
    line_begin = cur;
  }
  if (line_begin < cur) {
    lines.push_back(text.substr(line_begin));
  }
  return lines;
}

Guolin Ke's avatar
Guolin Ke committed
112
113
114
115
/*!
 * \brief Split a C string on any character contained in delimiters, dropping empty tokens.
 * \param c_str Null-terminated input
 * \param delimiters Null-terminated set of separator characters
 * \return Non-empty tokens in order of appearance
 */
inline static std::vector<std::string> Split(const char* c_str, const char* delimiters) {
  const std::string text(c_str);
  std::vector<std::string> tokens;
  size_t token_begin = 0;
  while (token_begin < text.length()) {
    const size_t hit = text.find_first_of(delimiters, token_begin);
    if (hit == std::string::npos) {
      // No further separator: the remainder is the last token.
      tokens.push_back(text.substr(token_begin));
      break;
    }
    if (token_begin < hit) {
      tokens.push_back(text.substr(token_begin, hit - token_begin));
    }
    token_begin = hit + 1;
  }
  return tokens;
}

141
142
143
144
/*!
 * \brief Parse an optionally signed decimal integer, skipping surrounding spaces.
 *        No overflow or error detection is performed; with no digits the result is 0.
 * \param p Null-terminated input
 * \param out Receives the parsed value
 * \return Pointer to the first character after the number and any trailing spaces
 */
template<typename T>
inline static const char* Atoi(const char* p, T* out) {
  while (*p == ' ') {
    ++p;
  }
  int sign = 1;
  if (*p == '-' || *p == '+') {
    if (*p == '-') {
      sign = -1;
    }
    ++p;
  }
  T value = 0;
  while (*p >= '0' && *p <= '9') {
    value = value * 10 + (*p - '0');
    ++p;
  }
  *out = static_cast<T>(sign * value);
  while (*p == ' ') {
    ++p;
  }
  return p;
}

165
/*!
 * \brief Integer power by recursive factoring of the exponent (by 2, then 3, then 1).
 * \param base Value to raise
 * \param power Integer exponent; negative exponents yield the reciprocal
 * \return base raised to power, as double
 */
template<typename T>
inline static double Pow(T base, int power) {
  if (power == 0) {
    return 1;
  }
  if (power < 0) {
    return 1.0 / Pow(base, -power);
  }
  if (power % 2 == 0) {
    return Pow(base * base, power / 2);
  }
  if (power % 3 == 0) {
    return Pow(base * base * base, power / 3);
  }
  return base * Pow(base, power - 1);
}

180
inline static const char* Atof(const char* p, double* out) {
Guolin Ke's avatar
Guolin Ke committed
181
  int frac;
182
  double sign, value, scale;
Guolin Ke's avatar
Guolin Ke committed
183
  *out = NAN;
Guolin Ke's avatar
Guolin Ke committed
184
185
186
187
188
  // Skip leading white space, if any.
  while (*p == ' ') {
    ++p;
  }
  // Get sign, if any.
189
  sign = 1.0;
Guolin Ke's avatar
Guolin Ke committed
190
  if (*p == '-') {
191
    sign = -1.0;
Guolin Ke's avatar
Guolin Ke committed
192
    ++p;
193
  } else if (*p == '+') {
Guolin Ke's avatar
Guolin Ke committed
194
195
196
    ++p;
  }

Guolin Ke's avatar
Guolin Ke committed
197
198
199
  // is a number
  if ((*p >= '0' && *p <= '9') || *p == '.' || *p == 'e' || *p == 'E') {
    // Get digits before decimal point or exponent, if any.
200
201
    for (value = 0.0; *p >= '0' && *p <= '9'; ++p) {
      value = value * 10.0 + (*p - '0');
Guolin Ke's avatar
Guolin Ke committed
202
    }
Guolin Ke's avatar
Guolin Ke committed
203

Guolin Ke's avatar
Guolin Ke committed
204
205
    // Get digits after decimal point, if any.
    if (*p == '.') {
206
207
      double right = 0.0;
      int nn = 0;
Guolin Ke's avatar
Guolin Ke committed
208
      ++p;
Guolin Ke's avatar
Guolin Ke committed
209
      while (*p >= '0' && *p <= '9') {
210
211
        right = (*p - '0') + right * 10.0;
        ++nn;
Guolin Ke's avatar
Guolin Ke committed
212
213
        ++p;
      }
214
      value += right / Pow(10.0, nn);
Guolin Ke's avatar
Guolin Ke committed
215
216
    }

Guolin Ke's avatar
Guolin Ke committed
217
218
    // Handle exponent, if any.
    frac = 0;
219
    scale = 1.0;
Guolin Ke's avatar
Guolin Ke committed
220
    if ((*p == 'e') || (*p == 'E')) {
Guolin Ke's avatar
Guolin Ke committed
221
      uint32_t expon;
Guolin Ke's avatar
Guolin Ke committed
222
      // Get sign of exponent, if any.
Guolin Ke's avatar
Guolin Ke committed
223
      ++p;
Guolin Ke's avatar
Guolin Ke committed
224
225
226
227
228
229
230
231
232
233
      if (*p == '-') {
        frac = 1;
        ++p;
      } else if (*p == '+') {
        ++p;
      }
      // Get digits of exponent, if any.
      for (expon = 0; *p >= '0' && *p <= '9'; ++p) {
        expon = expon * 10 + (*p - '0');
      }
234
235
236
      if (expon > 308) expon = 308;
      // Calculate scaling factor.
      while (expon >= 50) { scale *= 1E50; expon -= 50; }
Guolin Ke's avatar
Guolin Ke committed
237
      while (expon >= 8) { scale *= 1E8;  expon -= 8; }
238
      while (expon > 0) { scale *= 10.0; expon -= 1; }
Guolin Ke's avatar
Guolin Ke committed
239
    }
Guolin Ke's avatar
Guolin Ke committed
240
241
242
    // Return signed and scaled floating point result.
    *out = sign * (frac ? (value / scale) : (value * scale));
  } else {
243
    size_t cnt = 0;
244
    while (*(p + cnt) != '\0' && *(p + cnt) != ' '
245
246
247
           && *(p + cnt) != '\t' && *(p + cnt) != ','
           && *(p + cnt) != '\n' && *(p + cnt) != '\r'
           && *(p + cnt) != ':') {
248
249
      ++cnt;
    }
250
    if (cnt > 0) {
Guolin Ke's avatar
Guolin Ke committed
251
      std::string tmp_str(p, cnt);
Guolin Ke's avatar
Guolin Ke committed
252
      std::transform(tmp_str.begin(), tmp_str.end(), tmp_str.begin(), Common::tolower);
zhangjin's avatar
zhangjin committed
253
254
      if (tmp_str == std::string("na") || tmp_str == std::string("nan") ||
          tmp_str == std::string("null")) {
Guolin Ke's avatar
Guolin Ke committed
255
        *out = NAN;
256
      } else if (tmp_str == std::string("inf") || tmp_str == std::string("infinity")) {
257
        *out = sign * 1e308;
258
      } else {
259
        Log::Fatal("Unknown token %s in data file", tmp_str.c_str());
Guolin Ke's avatar
Guolin Ke committed
260
261
      }
      p += cnt;
262
    }
Guolin Ke's avatar
Guolin Ke committed
263
  }
Guolin Ke's avatar
Guolin Ke committed
264

Guolin Ke's avatar
Guolin Ke committed
265
266
267
  while (*p == ' ') {
    ++p;
  }
Guolin Ke's avatar
Guolin Ke committed
268

Guolin Ke's avatar
Guolin Ke committed
269
270
271
  return p;
}

272
inline static bool AtoiAndCheck(const char* p, int* out) {
273
274
275
276
277
278
279
  const char* after = Atoi(p, out);
  if (*after != '\0') {
    return false;
  }
  return true;
}

280
inline static bool AtofAndCheck(const char* p, double* out) {
281
282
283
284
285
286
287
  const char* after = Atof(p, out);
  if (*after != '\0') {
    return false;
  }
  return true;
}

288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
/*!
 * \brief Number of decimal digits needed to print n (1 for n == 0).
 * \param n Value to measure
 * \return Digit count in the range 1..10
 */
inline static unsigned CountDecimalDigit32(uint32_t n) {
#if defined(_MSC_VER) || defined(__GNUC__)
  static const uint32_t kPowersOf10[] = {
    0, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000
  };
  // Bit width of n, with n | 1 so that zero is treated like one.
#ifdef _MSC_VER
  unsigned long msb_index = 0;
  _BitScanReverse(&msb_index, n | 1);
  const uint32_t bit_width = static_cast<uint32_t>(msb_index + 1);
#else
  const uint32_t bit_width = static_cast<uint32_t>(32 - __builtin_clz(n | 1));
#endif
  // 1233 / 4096 approximates log10(2); the table lookup corrects the estimate.
  const uint32_t estimate = bit_width * 1233 >> 12;
  return estimate - (n < kPowersOf10[estimate]) + 1;
#else
  unsigned digits = 1;
  for (uint32_t bound = 10; digits < 10 && n >= bound; bound *= 10) {
    ++digits;
  }
  return digits;
#endif
}

/*!
 * \brief Write the decimal representation of value into buffer, null-terminated.
 * \param value Number to print
 * \param buffer Destination; must hold at least 11 chars (10 digits + terminator)
 */
inline static void Uint32ToStr(uint32_t value, char* buffer) {
  // Two-character decimal rendering of every number 00..99. Kept as static
  // constant data so it is not rebuilt on the stack for every call.
  static const char kDigitsLut[] =
    "00010203040506070809"
    "10111213141516171819"
    "20212223242526272829"
    "30313233343536373839"
    "40414243444546474849"
    "50515253545556575859"
    "60616263646566676869"
    "70717273747576777879"
    "80818283848586878889"
    "90919293949596979899";
  // Digits are produced from least to most significant, so start at the end.
  char* cursor = buffer + CountDecimalDigit32(value);
  *cursor = '\0';

  while (value >= 100) {
    const unsigned pair = (value % 100) << 1;
    value /= 100;
    *--cursor = kDigitsLut[pair + 1];
    *--cursor = kDigitsLut[pair];
  }

  if (value < 10) {
    *--cursor = static_cast<char>(value) + '0';
  } else {
    const unsigned pair = value << 1;
    *--cursor = kDigitsLut[pair + 1];
    *--cursor = kDigitsLut[pair];
  }
}

/*!
 * \brief Write the decimal representation of a signed 32-bit value into buffer.
 * \param value Number to print
 * \param buffer Destination; receives an optional '-' followed by the digits
 *        written by Uint32ToStr
 */
inline static void Int32ToStr(int32_t value, char* buffer) {
  uint32_t magnitude = static_cast<uint32_t>(value);
  if (value < 0) {
    *buffer = '-';
    ++buffer;
    // Two's-complement negation in unsigned arithmetic (also valid for INT32_MIN).
    magnitude = 0u - magnitude;
  }
  Uint32ToStr(magnitude, buffer);
}

366
/*!
 * \brief Format value with 17 significant digits ("%.17g") into buffer.
 * \param value Number to print
 * \param buffer Destination character array
 * \param buffer_len Capacity of buffer in bytes, including the terminator;
 *        the write never exceeds this size on any platform
 */
inline static void DoubleToStr(double value, char* buffer, size_t buffer_len) {
  #ifdef _MSC_VER
  sprintf_s(buffer, buffer_len, "%.17g", value);
  #else
  // Bounded write: output is truncated instead of overrunning the buffer.
  std::snprintf(buffer, buffer_len, "%.17g", value);
  #endif
}

Guolin Ke's avatar
Guolin Ke committed
378
379
380
381
382
383
384
385
386
387
388
389
390
391
/*!
 * \brief Advance past any leading spaces and tabs.
 * \param p Null-terminated input
 * \return Pointer to the first character that is neither ' ' nor '\t'
 */
inline static const char* SkipSpaceAndTab(const char* p) {
  for (; *p == ' ' || *p == '\t'; ++p) {
  }
  return p;
}

/*!
 * \brief Advance past any leading '\n', '\r' and space characters.
 * \param p Null-terminated input
 * \return Pointer to the first character that is none of those
 */
inline static const char* SkipReturn(const char* p) {
  for (; *p == '\n' || *p == '\r' || *p == ' '; ++p) {
  }
  return p;
}

Guolin Ke's avatar
Guolin Ke committed
392
393
/*!
 * \brief Convert every element of arr to T2 with static_cast.
 * \param arr Source values
 * \return Vector of the same length holding the converted values
 */
template<typename T, typename T2>
inline static std::vector<T2> ArrayCast(const std::vector<T>& arr) {
  std::vector<T2> converted(arr.size());
  std::transform(arr.begin(), arr.end(), converted.begin(),
                 [](const T& item) { return static_cast<T2>(item); });
  return converted;
}

401
402
/*!
 * \brief Functor that prints a value into a caller-supplied character buffer.
 *        Primary template: forwards the value to Int32ToStr.
 */
template<typename T, bool is_float, bool is_unsign>
struct __TToStringHelperFast {
  void operator()(T value, char* buffer, size_t) const {
    Int32ToStr(value, buffer);
  }
};

/*!
 * \brief Floating point specialization: prints with "%g".
 *        The write is bounded by buf_len on every platform.
 */
template<typename T>
struct __TToStringHelperFast<T, true, false> {
  void operator()(T value, char* buffer, size_t buf_len) const {
    #ifdef _MSC_VER
    sprintf_s(buffer, buf_len, "%g", value);
    #else
    // Bounded write: output is truncated instead of overrunning the buffer.
    std::snprintf(buffer, buf_len, "%g", value);
    #endif
  }
};

/*!
 * \brief Unsigned integer specialization: forwards the value to Uint32ToStr.
 */
template<typename T>
struct __TToStringHelperFast<T, false, true> {
  void operator()(T value, char* buffer, size_t) const {
    Uint32ToStr(value, buffer);
  }
};

430
template<typename T>
431
432
inline static std::string ArrayToStringFast(const std::vector<T>& arr, size_t n) {
  if (arr.empty() || n == 0) {
433
    return std::string("");
Guolin Ke's avatar
Guolin Ke committed
434
  }
435
436
437
  __TToStringHelperFast<T, std::is_floating_point<T>::value, std::is_unsigned<T>::value> helper;
  const size_t buf_len = 16;
  std::vector<char> buffer(buf_len);
438
  std::stringstream str_buf;
439
440
441
442
443
  helper(arr[0], buffer.data(), buf_len);
  str_buf << buffer.data();
  for (size_t i = 1; i < std::min(n, arr.size()); ++i) {
    helper(arr[i], buffer.data(), buf_len);
    str_buf << ' ' << buffer.data();
Guolin Ke's avatar
Guolin Ke committed
444
  }
445
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
446
447
}

448
inline static std::string ArrayToString(const std::vector<double>& arr, size_t n) {
Guolin Ke's avatar
Guolin Ke committed
449
450
451
  if (arr.empty() || n == 0) {
    return std::string("");
  }
452
453
  const size_t buf_len = 32;
  std::vector<char> buffer(buf_len);
Guolin Ke's avatar
Guolin Ke committed
454
  std::stringstream str_buf;
455
456
  DoubleToStr(arr[0], buffer.data(), buf_len);
  str_buf << buffer.data();
Guolin Ke's avatar
Guolin Ke committed
457
  for (size_t i = 1; i < std::min(n, arr.size()); ++i) {
458
459
    DoubleToStr(arr[i], buffer.data(), buf_len);
    str_buf << ' ' << buffer.data();
Guolin Ke's avatar
Guolin Ke committed
460
461
462
463
  }
  return str_buf.str();
}

464
465
466
/*!
 * \brief Functor converting a string to T. Primary template: parses with Atoi.
 */
template<typename T, bool is_float>
struct __StringToTHelper {
  T operator()(const std::string& str) const {
    T parsed = 0;
    Atoi(str.c_str(), &parsed);
    return parsed;
  }
};

/*!
 * \brief Floating point specialization: parses with std::stod and casts to T.
 */
template<typename T>
struct __StringToTHelper<T, true> {
  T operator()(const std::string& str) const {
    const double parsed = std::stod(str);
    return static_cast<T>(parsed);
  }
};

Guolin Ke's avatar
Guolin Ke committed
480
template<typename T>
481
inline static std::vector<T> StringToArray(const std::string& str, char delimiter) {
Guolin Ke's avatar
Guolin Ke committed
482
  std::vector<std::string> strs = Split(str.c_str(), delimiter);
483
484
  std::vector<T> ret;
  ret.reserve(strs.size());
485
  __StringToTHelper<T, std::is_floating_point<T>::value> helper;
486
487
  for (const auto& s : strs) {
    ret.push_back(helper(s));
Guolin Ke's avatar
Guolin Ke committed
488
489
490
491
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
492
template<typename T>
493
494
495
496
497
498
inline static std::vector<T> StringToArray(const std::string& str, int n) {
  if (n == 0) {
    return std::vector<T>();
  }
  std::vector<std::string> strs = Split(str.c_str(), ' ');
  CHECK(strs.size() == static_cast<size_t>(n));
Guolin Ke's avatar
Guolin Ke committed
499
  std::vector<T> ret;
500
501
502
503
  ret.reserve(strs.size());
  __StringToTHelper<T, std::is_floating_point<T>::value> helper;
  for (const auto& s : strs) {
    ret.push_back(helper(s));
Guolin Ke's avatar
Guolin Ke committed
504
505
506
507
  }
  return ret;
}

508
509
510
511
512
513
514
515
516
517
518
519
/*!
 * \brief Functor parsing one value from a character stream and returning the
 *        position after it. Primary template: parses with Atoi.
 */
template<typename T, bool is_float>
struct __StringToTHelperFast {
  const char* operator()(const char*p, T* out) const {
    const char* const next = Atoi(p, out);
    return next;
  }
};

/*!
 * \brief Floating point specialization: parses a double with Atof, then casts to T.
 */
template<typename T>
struct __StringToTHelperFast<T, true> {
  const char* operator()(const char*p, T* out) const {
    double parsed = 0.0f;
    const char* const next = Atof(p, &parsed);
    *out = static_cast<T>(parsed);
    return next;
  }
};

template<typename T>
inline static std::vector<T> StringToArrayFast(const std::string& str, int n) {
  if (n == 0) {
    return std::vector<T>();
  }
  auto p_str = str.c_str();
  __StringToTHelperFast<T, std::is_floating_point<T>::value> helper;
  std::vector<T> ret(n);
  for (int i = 0; i < n; ++i) {
    p_str = helper(p_str, &ret[i]);
  }
  return ret;
}

539
/*!
 * \brief Stream all elements into one string, separated by delimiter.
 *        The stream precision is set to digits10 + 2 of double.
 * \param strs Elements to join
 * \param delimiter Text inserted between consecutive elements
 * \return Joined text, or "" for an empty vector
 */
template<typename T>
inline static std::string Join(const std::vector<T>& strs, const char* delimiter) {
  if (strs.empty()) {
    return std::string("");
  }
  std::stringstream joined;
  joined << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  auto it = strs.begin();
  joined << *it;
  for (++it; it != strs.end(); ++it) {
    joined << delimiter << *it;
  }
  return joined.str();
}

/*!
 * \brief int8_t specialization: elements are widened to int16_t so they are
 *        printed as numbers instead of characters.
 */
template<>
inline std::string Join<int8_t>(const std::vector<int8_t>& strs, const char* delimiter) {
  if (strs.empty()) {
    return std::string("");
  }
  std::stringstream joined;
  joined << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  auto it = strs.begin();
  joined << static_cast<int16_t>(*it);
  for (++it; it != strs.end(); ++it) {
    joined << delimiter << static_cast<int16_t>(*it);
  }
  return joined.str();
}

569
/*!
 * \brief Stream the elements in [start, end) into one string, separated by delimiter.
 *        The stream precision is set to digits10 + 2 of double.
 * \param strs Elements to join
 * \param start Index of the first element; clamped to the last valid index
 * \param end One past the last element; clamped to strs.size()
 * \param delimiter Text inserted between consecutive elements
 * \return Joined text, or "" when the range is empty/inverted or strs is empty
 */
template<typename T>
inline static std::string Join(const std::vector<T>& strs, size_t start, size_t end, const char* delimiter) {
  // The indices are unsigned, so compare them directly instead of testing
  // "end - start <= 0", which wraps around when end < start. The emptiness
  // check keeps "size() - 1" below from underflowing.
  if (end <= start || strs.empty()) {
    return std::string("");
  }
  start = std::min(start, static_cast<size_t>(strs.size()) - 1);
  end = std::min(end, static_cast<size_t>(strs.size()));
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  str_buf << strs[start];
  for (size_t i = start + 1; i < end; ++i) {
    str_buf << delimiter;
    str_buf << strs[i];
  }
  return str_buf.str();
}

586
/*!
 * \brief Smallest power of two that is greater than or equal to x.
 * \param x Value to round up
 * \return The power of two (1 for x <= 1), or 0 when no power of two
 *         representable in int64_t is large enough
 */
inline static int64_t Pow2RoundUp(int64_t x) {
  // 2^62 is the largest power of two that fits in int64_t; stopping there
  // avoids shifting a signed value into (and past) the sign bit.
  for (int shift = 0; shift < 63; ++shift) {
    const int64_t candidate = static_cast<int64_t>(1) << shift;
    if (candidate >= x) {
      return candidate;
    }
  }
  return 0;
}

597
/*!
 * \brief Do inplace softmax transformation on p_rec
 * \param p_rec The input/output vector of the values; an empty vector is left untouched.
 */
inline static void Softmax(std::vector<double>* p_rec) {
  std::vector<double> &rec = *p_rec;
  // Nothing to normalize, and rec[0] below would be out of bounds.
  if (rec.empty()) {
    return;
  }
  // Subtract the maximum before exponentiating so exp() cannot overflow.
  double wmax = rec[0];
  for (size_t i = 1; i < rec.size(); ++i) {
    wmax = std::max(rec[i], wmax);
  }
  double wsum = 0.0;
  for (size_t i = 0; i < rec.size(); ++i) {
    rec[i] = std::exp(rec[i] - wmax);
    wsum += rec[i];
  }
  for (size_t i = 0; i < rec.size(); ++i) {
    rec[i] /= wsum;
  }
}

617
/*!
 * \brief Softmax transformation of input, written to output
 * \param input The input values
 * \param output Receives the normalized values; may hold at least len elements
 * \param len Number of elements; nothing is read or written when len <= 0
 */
inline static void Softmax(const double* input, double* output, int len) {
  // Nothing to normalize, and input[0] below would be out of bounds.
  if (len <= 0) {
    return;
  }
  // Subtract the maximum before exponentiating so exp() cannot overflow.
  double wmax = input[0];
  for (int i = 1; i < len; ++i) {
    wmax = std::max(input[i], wmax);
  }
  double wsum = 0.0;
  for (int i = 0; i < len; ++i) {
    output[i] = std::exp(input[i] - wmax);
    wsum += output[i];
  }
  for (int i = 0; i < len; ++i) {
    output[i] /= wsum;
  }
}

Guolin Ke's avatar
Guolin Ke committed
632
633
634
/*!
 * \brief Collect the raw pointers held by a vector of unique_ptr as const pointers.
 * \param input Owning pointers; ownership is not transferred
 * \return One const T* per element, in the same order
 */
template<typename T>
std::vector<const T*> ConstPtrInVectorWrapper(const std::vector<std::unique_ptr<T>>& input) {
  std::vector<const T*> raw_ptrs;
  for (const auto& owner : input) {
    raw_ptrs.push_back(owner.get());
  }
  return raw_ptrs;
}

Guolin Ke's avatar
Guolin Ke committed
641
/*!
 * \brief Stable-sort keys[start..] and reorder values[start..] in the same way.
 *        Elements before start are left untouched.
 * \param keys Sort keys; modified in place
 * \param values Payload paired with keys by index; modified in place
 * \param start Index of the first element taking part in the sort
 * \param is_reverse Sort descending when true, ascending otherwise
 */
template<typename T1, typename T2>
inline static void SortForPair(std::vector<T1>* keys, std::vector<T2>* values, size_t start, bool is_reverse = false) {
  std::vector<std::pair<T1, T2>> arr;
  auto& ref_key = *keys;
  auto& ref_value = *values;
  if (start < ref_key.size()) {
    arr.reserve(ref_key.size() - start);
  }
  for (size_t i = start; i < ref_key.size(); ++i) {
    arr.emplace_back(ref_key[i], ref_value[i]);
  }
  if (!is_reverse) {
    std::stable_sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
      return a.first < b.first;
    });
  } else {
    std::stable_sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
      return a.first > b.first;
    });
  }
  // arr is indexed from 0 while the destination range begins at start, so
  // the offset has to be applied when writing the sorted pairs back.
  for (size_t i = 0; i < arr.size(); ++i) {
    ref_key[start + i] = arr[i].first;
    ref_value[start + i] = arr[i].second;
  }
}

664
/*!
 * \brief Collect the data() pointer of every inner vector.
 * \param data Vector of vectors; the returned pointers refer to its storage
 * \return One T* per inner vector, in the same order
 */
template <typename T>
inline static std::vector<T*> Vector2Ptr(std::vector<std::vector<T>>* data) {
  std::vector<T*> ptr;
  ptr.reserve(data->size());
  for (auto& row : *data) {
    ptr.push_back(row.data());
  }
  return ptr;
}

/*!
 * \brief Length of every inner vector, as int.
 * \param data Vector of vectors
 * \return One size per inner vector, in the same order
 */
template <typename T>
inline static std::vector<int> VectorSize(const std::vector<std::vector<T>>& data) {
  std::vector<int> sizes(data.size());
  std::transform(data.begin(), data.end(), sizes.begin(),
                 [](const std::vector<T>& row) { return static_cast<int>(row.size()); });
  return sizes;
}

Guolin Ke's avatar
Guolin Ke committed
683
/*!
 * \brief Clamp x into [-1e300, 1e300] and replace NaN with 0.
 * \param x Value to sanitize
 * \return 0.0 for NaN, the nearer bound when |x| >= 1e300, otherwise x
 */
inline static double AvoidInf(double x) {
  const double kLimit = 1e300;
  if (std::isnan(x)) {
    return 0.0;
  }
  if (x >= kLimit) {
    return kLimit;
  }
  if (x <= -kLimit) {
    return -kLimit;
  }
  return x;
}

695
/*!
 * \brief Clamp x to a magnitude of at most 1e38f and replace NaN with 0.
 *        The range tests compare x against the double constant 1e38.
 * \param x Value to sanitize
 * \return 0.0f for NaN, +/-1e38f when x is outside the range, otherwise x
 */
inline static float AvoidInf(float x) {
  if (std::isnan(x)) {
    return 0.0f;
  }
  if (x >= 1e38) {
    return 1e38f;
  }
  if (x <= -1e38) {
    return -1e38f;
  }
  return x;
}

/*!
 * \brief Type tag helper: returns a null pointer whose pointee type is the
 *        iterator's value_type. The iterator argument itself is not used.
 */
template<typename _Iter> inline
static typename std::iterator_traits<_Iter>::value_type* IteratorValType(_Iter) {
  return nullptr;
}

712
template<typename _RanIt, typename _Pr, typename _VTRanIt> inline
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) {
  size_t len = _Last - _First;
  const size_t kMinInnerLen = 1024;
  int num_threads = 1;
  #pragma omp parallel
  #pragma omp master
  {
    num_threads = omp_get_num_threads();
  }
  if (len <= kMinInnerLen || num_threads <= 1) {
    std::sort(_First, _Last, _Pred);
    return;
  }
  size_t inner_size = (len + num_threads - 1) / num_threads;
  inner_size = std::max(inner_size, kMinInnerLen);
  num_threads = static_cast<int>((len + inner_size - 1) / inner_size);
729
  #pragma omp parallel for schedule(static, 1)
730
731
732
733
734
735
736
737
738
739
740
  for (int i = 0; i < num_threads; ++i) {
    size_t left = inner_size*i;
    size_t right = left + inner_size;
    right = std::min(right, len);
    if (right > left) {
      std::sort(_First + left, _First + right, _Pred);
    }
  }
  // Buffer for merge.
  std::vector<_VTRanIt> temp_buf(len);
  _RanIt buf = temp_buf.begin();
741
  size_t s = inner_size;
742
743
744
  // Recursive merge
  while (s < len) {
    int loop_size = static_cast<int>((len + s * 2 - 1) / (s * 2));
745
    #pragma omp parallel for schedule(static, 1)
746
747
748
749
750
    for (int i = 0; i < loop_size; ++i) {
      size_t left = i * 2 * s;
      size_t mid = left + s;
      size_t right = mid + s;
      right = std::min(len, right);
Guolin Ke's avatar
Guolin Ke committed
751
      if (mid >= right) { continue; }
752
753
754
755
756
757
758
      std::copy(_First + left, _First + mid, buf + left);
      std::merge(buf + left, buf + mid, _First + mid, _First + right, _First + left, _Pred);
    }
    s *= 2;
  }
}

759
/*!
 * \brief Convenience overload: derives the element type from the iterator and
 *        forwards to the four-argument ParallelSort.
 * \param _First Begin of the range
 * \param _Last End of the range
 * \param _Pred Ordering predicate
 */
template<typename _RanIt, typename _Pr> inline
static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred) {
  ParallelSort(_First, _Last, _Pred, IteratorValType(_First));
}

764
// Check that all y[] are in interval [ymin, ymax] (end points included); throws error if not
765
template <typename T>
766
inline static void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, const char *callername) {
767
  auto fatal_msg = [&y, &ymin, &ymax, &callername](int i) {
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
    std::ostringstream os;
    os << "[%s]: does not tolerate element [#%i = " << y[i] << "] outside [" << ymin << ", " << ymax << "]";
    Log::Fatal(os.str().c_str(), callername, i);
  };
  for (int i = 1; i < ny; i += 2) {
    if (y[i - 1] < y[i]) {
      if (y[i - 1] < ymin) {
        fatal_msg(i - 1);
      } else if (y[i] > ymax) {
        fatal_msg(i);
      }
    } else {
      if (y[i - 1] > ymax) {
        fatal_msg(i - 1);
      } else if (y[i] < ymin) {
        fatal_msg(i);
      }
    }
  }
787
  if (ny & 1) {  // odd
788
789
    if (y[ny - 1] < ymin || y[ny - 1] > ymax) {
      fatal_msg(ny - 1);
790
791
792
793
794
795
    }
  }
}

// One-pass scan over array w with nw elements: find min, max and sum of elements;
// this is useful for checking weight requirements.
// Elements are consumed in pairs; each output pointer may be nullptr to skip that result.
template <typename T1, typename T2>
inline static void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) {
  T1 minw = w[0];
  T1 maxw = w[0];
  T1 sumw = w[0];
  int i = 2;
  if (!(nw & 1)) {  // even: seed the scan with the first pair
    if (w[0] < w[1]) {
      maxw = w[1];
    } else {
      minw = w[1];
    }
    sumw = w[0] + w[1];
    i = 3;
  }
  for (; i < nw; i += 2) {
    const T1& prev = w[i - 1];
    const T1& cur = w[i];
    if (prev < cur) {
      minw = std::min(minw, prev);
      maxw = std::max(maxw, cur);
    } else {
      minw = std::min(minw, cur);
      maxw = std::max(maxw, prev);
    }
    sumw += prev + cur;
  }
  if (mi != nullptr) {
    *mi = minw;
  }
  if (ma != nullptr) {
    *ma = maxw;
  }
  if (su != nullptr) {
    *su = static_cast<T2>(sumw);
  }
}

839
/*!
 * \brief Create a zero-filled bitset able to hold n bits.
 * \param n Number of bits required
 * \return Vector with one 32-bit word per started group of 32 bits
 */
inline static std::vector<uint32_t> EmptyBitset(int n) {
  const int full_words = n / 32;
  const int partial_word = (n % 32 != 0) ? 1 : 0;
  return std::vector<uint32_t>(full_words + partial_word);
}

/*!
 * \brief Set bit number val in the bitset, growing it when necessary.
 * \param vec Bitset stored as 32-bit words; resized with zero words if too short
 * \param val Index of the bit to set
 */
template<typename T>
inline static void InsertBitset(std::vector<uint32_t>* vec, const T val) {
  auto& ref_v = *vec;
  const int i1 = val / 32;
  const int i2 = val % 32;
  if (static_cast<int>(vec->size()) < i1 + 1) {
    vec->resize(i1 + 1, 0);
  }
  // Shift an unsigned one: "1 << 31" would overflow a signed int.
  ref_v[i1] |= (static_cast<uint32_t>(1) << i2);
}

856
857
/*!
 * \brief Build a bitset with the bits listed in vals set.
 * \param vals Array of bit indices
 * \param n Number of entries in vals
 * \return Bitset stored as 32-bit words, just large enough for the highest index
 */
template<typename T>
inline static std::vector<uint32_t> ConstructBitset(const T* vals, int n) {
  std::vector<uint32_t> ret;
  for (int i = 0; i < n; ++i) {
    const int i1 = vals[i] / 32;
    const int i2 = vals[i] % 32;
    if (static_cast<int>(ret.size()) < i1 + 1) {
      ret.resize(i1 + 1, 0);
    }
    // Shift an unsigned one: "1 << 31" would overflow a signed int.
    ret[i1] |= (static_cast<uint32_t>(1) << i2);
  }
  return ret;
}

870
871
/*!
 * \brief Test whether bit number pos is set in a bitset.
 * \param bits Bitset stored as 32-bit words
 * \param n Number of words in bits
 * \param pos Index of the bit to test
 * \return true if the bit is set; false if it is clear or its word index is >= n
 */
template<typename T>
inline static bool FindInBitset(const uint32_t* bits, int n, T pos) {
  const int word = pos / 32;
  if (word >= n) {
    return false;
  }
  const int bit = pos % 32;
  return ((bits[word] >> bit) & 1) != 0;
}

880
881
882
883
884
885
886
887
888
/*!
 * \brief Check that b does not exceed the next representable double above a.
 * \param a Reference value
 * \param b Value to compare
 * \return true if b <= nextafter(a, +infinity)
 */
inline static bool CheckDoubleEqualOrdered(double a, double b) {
  return b <= std::nextafter(a, INFINITY);
}

/*!
 * \brief Next representable double above a.
 * \param a Reference value
 * \return nextafter(a, +infinity)
 */
inline static double GetDoubleUpperBound(double a) {
  const double next_up = std::nextafter(a, INFINITY);
  return next_up;
}

889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
/*!
 * \brief Length of the first line in str.
 * \param str Null-terminated input
 * \return Number of characters before the first '\n', '\r' or the terminator
 */
inline static size_t GetLine(const char* str) {
  const char* cur = str;
  for (; *cur != '\0' && *cur != '\n' && *cur != '\r'; ++cur) {
  }
  return cur - str;
}

/*!
 * \brief Skip a single line ending: an optional '\r' followed by an optional '\n'.
 * \param str Null-terminated input
 * \return Pointer just past the line ending, or str when there is none
 */
inline static const char* SkipNewLine(const char* str) {
  const char* cur = str;
  if (*cur == '\r') {
    ++cur;
  }
  if (*cur == '\n') {
    ++cur;
  }
  return cur;
}

907
908
909
910
911
/*!
 * \brief Sign of x.
 * \param x Value to inspect
 * \return 1 when x > 0, -1 when x < 0, otherwise 0
 */
template <typename T>
static int Sign(T x) {
  const int is_positive = (x > T(0));
  const int is_negative = (x < T(0));
  return is_positive - is_negative;
}

Guolin Ke's avatar
Guolin Ke committed
912
913
914
915
916
917
918
919
920
/*!
 * \brief Natural logarithm that maps every non-positive (or NaN) input to -infinity.
 * \param x Value to take the logarithm of
 * \return std::log(x) when x > 0, otherwise -INFINITY
 */
template <typename T>
static T SafeLog(T x) {
  if (!(x > 0)) {
    return -INFINITY;
  }
  return std::log(x);
}

921
922
923
924
925
926
927
928
929
/*!
 * \brief Check that every byte of s is a 7-bit ASCII character.
 * \param s String to inspect
 * \return true if no byte has a value above 127
 */
inline bool CheckASCII(const std::string& s) {
  return std::all_of(s.begin(), s.end(), [](char c) {
    return static_cast<unsigned char>(c) <= 127;
  });
}

930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
/*!
 * \brief Check that s contains none of the characters " , : [ ] { }.
 * \param s String to inspect
 * \return false as soon as one of those characters is found, true otherwise
 */
inline bool CheckAllowedJSON(const std::string& s) {
  for (const char c : s) {
    switch (static_cast<unsigned char>(c)) {
      case 34:   // "
      case 44:   // ,
      case 58:   // :
      case 91:   // [
      case 93:   // ]
      case 123:  // {
      case 125:  // }
        return false;
      default:
        break;
    }
  }
  return true;
}

Guolin Ke's avatar
Guolin Ke committed
948
949
950
951
}  // namespace Common

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
952
#endif   // LIGHTGBM_UTILS_COMMON_FUN_H_