common.h 23.7 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
7
8
#ifndef LIGHTGBM_UTILS_COMMON_FUN_H_
#define LIGHTGBM_UTILS_COMMON_FUN_H_

#include <LightGBM/utils/log.h>
9
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
10

11
#include <limits>
Guolin Ke's avatar
Guolin Ke committed
12
#include <string>
13
#include <algorithm>
14
#include <cmath>
15
16
#include <cstdint>
#include <cstdio>
Guolin Ke's avatar
Guolin Ke committed
17
#include <functional>
18
#include <iomanip>
19
#include <iterator>
20
21
#include <memory>
#include <sstream>
Guolin Ke's avatar
Guolin Ke committed
22
#include <type_traits>
23
24
#include <utility>
#include <vector>
Guolin Ke's avatar
Guolin Ke committed
25

26
27
28
29
#ifdef _MSC_VER
#include "intrin.h"
#endif

Guolin Ke's avatar
Guolin Ke committed
30
31
32
33
namespace LightGBM {

namespace Common {

34
/*!
 * \brief ASCII-only lowercase conversion (locale independent).
 * \param in Character to convert.
 * \return The lowercase letter if `in` is in 'A'..'Z', otherwise `in` unchanged.
 */
inline static char tolower(char in) {
  const bool is_upper = in >= 'A' && in <= 'Z';
  return is_upper ? static_cast<char>(in + ('a' - 'A')) : in;
}

40
/*!
 * \brief Strip leading and trailing whitespace (" \f\n\r\t\v").
 * \param str Input string (taken by value).
 * \return The trimmed string; empty if `str` was empty or all whitespace.
 */
inline static std::string Trim(std::string str) {
  const char* const kWhitespace = " \f\n\r\t\v";
  const size_t first = str.find_first_not_of(kWhitespace);
  if (first == std::string::npos) {
    return std::string();
  }
  const size_t last = str.find_last_not_of(kWhitespace);
  return str.substr(first, last - first + 1);
}

49
/*!
 * \brief Strip any leading and trailing single/double quote characters.
 * \param str Input string (taken by value).
 * \return `str` without surrounding quotes; inner quotes are preserved.
 */
inline static std::string RemoveQuotationSymbol(std::string str) {
  const char* const kQuotes = "'\"";
  const size_t first = str.find_first_not_of(kQuotes);
  if (first == std::string::npos) {
    return std::string();
  }
  const size_t last = str.find_last_not_of(kQuotes);
  return str.substr(first, last - first + 1);
}
Guolin Ke's avatar
Guolin Ke committed
57

Guolin Ke's avatar
Guolin Ke committed
58
59
60
61
62
63
64
/*!
 * \brief Test whether `str` begins with `prefix`.
 * \param str String to inspect.
 * \param prefix Candidate prefix (an empty prefix always matches).
 * \return true iff the first prefix.size() characters of `str` equal `prefix`.
 */
inline static bool StartsWith(const std::string& str, const std::string& prefix) {
  // compare() checks in place: no copy of `prefix` and no temporary substr().
  return str.size() >= prefix.size() && str.compare(0, prefix.size(), prefix) == 0;
}
Guolin Ke's avatar
Guolin Ke committed
65

Guolin Ke's avatar
Guolin Ke committed
66
/*!
 * \brief Split a C string on a single delimiter character.
 * \param c_str Null-terminated input.
 * \param delimiter Separator character.
 * \return Non-empty tokens in order; runs of delimiters produce no empty tokens.
 */
inline static std::vector<std::string> Split(const char* c_str, char delimiter) {
  std::vector<std::string> tokens;
  const std::string str(c_str);
  size_t token_begin = 0;
  while (token_begin < str.size()) {
    size_t token_end = str.find(delimiter, token_begin);
    if (token_end == std::string::npos) {
      token_end = str.size();
    }
    if (token_end > token_begin) {
      tokens.push_back(str.substr(token_begin, token_end - token_begin));
    }
    token_begin = token_end + 1;
  }
  return tokens;
}

/*!
 * \brief Split a C string into lines on '\n' and/or '\r'.
 * \param c_str Null-terminated input.
 * \return Non-empty lines in order; consecutive line-break characters
 *         (e.g. "\r\n", blank lines) never produce empty entries.
 */
inline static std::vector<std::string> SplitLines(const char* c_str) {
  std::vector<std::string> lines;
  const std::string str(c_str);
  const char* const kLineBreaks = "\r\n";
  size_t begin = str.find_first_not_of(kLineBreaks);
  while (begin != std::string::npos) {
    const size_t end = str.find_first_of(kLineBreaks, begin);
    if (end == std::string::npos) {
      lines.push_back(str.substr(begin));
      break;
    }
    lines.push_back(str.substr(begin, end - begin));
    begin = str.find_first_not_of(kLineBreaks, end);
  }
  return lines;
}

Guolin Ke's avatar
Guolin Ke committed
112
113
114
115
/*!
 * \brief Split a C string on any character of a delimiter set.
 * \param c_str Null-terminated input.
 * \param delimiters Null-terminated set of separator characters.
 * \return Non-empty tokens in order; adjacent delimiters produce no empty tokens.
 */
inline static std::vector<std::string> Split(const char* c_str, const char* delimiters) {
  std::vector<std::string> tokens;
  const std::string str(c_str);
  size_t token_begin = 0;
  while (token_begin < str.size()) {
    size_t token_end = str.find_first_of(delimiters, token_begin);
    if (token_end == std::string::npos) {
      token_end = str.size();
    }
    if (token_end > token_begin) {
      tokens.push_back(str.substr(token_begin, token_end - token_begin));
    }
    token_begin = token_end + 1;
  }
  return tokens;
}

141
142
143
144
/*!
 * \brief Parse a decimal integer with optional sign, skipping surrounding spaces.
 * \param p Input cursor; leading ' ' characters are skipped.
 * \param out Receives sign * magnitude (0 if no digits follow the sign).
 * \return Cursor just past the digits and any trailing ' ' characters.
 * \note No overflow detection is performed.
 */
template<typename T>
inline static const char* Atoi(const char* p, T* out) {
  while (*p == ' ') {
    ++p;
  }
  int sign = 1;
  if (*p == '-' || *p == '+') {
    if (*p == '-') {
      sign = -1;
    }
    ++p;
  }
  T value = 0;
  while (*p >= '0' && *p <= '9') {
    value = value * 10 + (*p - '0');
    ++p;
  }
  *out = static_cast<T>(sign * value);
  while (*p == ' ') {
    ++p;
  }
  return p;
}

165
/*!
 * \brief Integer power by repeated squaring/cubing.
 * \param base Base value (multiplied in type T before conversion to double).
 * \param power Exponent; negative values yield the reciprocal.
 * \return base^power as double.
 */
template<typename T>
inline static double Pow(T base, int power) {
  if (power < 0) {
    return 1.0 / Pow(base, -power);
  }
  if (power == 0) {
    return 1;
  }
  if (power % 2 == 0) {
    return Pow(base * base, power / 2);
  }
  if (power % 3 == 0) {
    return Pow(base * base * base, power / 3);
  }
  return base * Pow(base, power - 1);
}

180
inline static const char* Atof(const char* p, double* out) {
Guolin Ke's avatar
Guolin Ke committed
181
  int frac;
182
  double sign, value, scale;
Guolin Ke's avatar
Guolin Ke committed
183
  *out = NAN;
Guolin Ke's avatar
Guolin Ke committed
184
185
186
187
188
  // Skip leading white space, if any.
  while (*p == ' ') {
    ++p;
  }
  // Get sign, if any.
189
  sign = 1.0;
Guolin Ke's avatar
Guolin Ke committed
190
  if (*p == '-') {
191
    sign = -1.0;
Guolin Ke's avatar
Guolin Ke committed
192
    ++p;
193
  } else if (*p == '+') {
Guolin Ke's avatar
Guolin Ke committed
194
195
196
    ++p;
  }

Guolin Ke's avatar
Guolin Ke committed
197
198
199
  // is a number
  if ((*p >= '0' && *p <= '9') || *p == '.' || *p == 'e' || *p == 'E') {
    // Get digits before decimal point or exponent, if any.
200
201
    for (value = 0.0; *p >= '0' && *p <= '9'; ++p) {
      value = value * 10.0 + (*p - '0');
Guolin Ke's avatar
Guolin Ke committed
202
    }
Guolin Ke's avatar
Guolin Ke committed
203

Guolin Ke's avatar
Guolin Ke committed
204
205
    // Get digits after decimal point, if any.
    if (*p == '.') {
206
207
      double right = 0.0;
      int nn = 0;
Guolin Ke's avatar
Guolin Ke committed
208
      ++p;
Guolin Ke's avatar
Guolin Ke committed
209
      while (*p >= '0' && *p <= '9') {
210
211
        right = (*p - '0') + right * 10.0;
        ++nn;
Guolin Ke's avatar
Guolin Ke committed
212
213
        ++p;
      }
214
      value += right / Pow(10.0, nn);
Guolin Ke's avatar
Guolin Ke committed
215
216
    }

Guolin Ke's avatar
Guolin Ke committed
217
218
    // Handle exponent, if any.
    frac = 0;
219
    scale = 1.0;
Guolin Ke's avatar
Guolin Ke committed
220
    if ((*p == 'e') || (*p == 'E')) {
Guolin Ke's avatar
Guolin Ke committed
221
      uint32_t expon;
Guolin Ke's avatar
Guolin Ke committed
222
      // Get sign of exponent, if any.
Guolin Ke's avatar
Guolin Ke committed
223
      ++p;
Guolin Ke's avatar
Guolin Ke committed
224
225
226
227
228
229
230
231
232
233
      if (*p == '-') {
        frac = 1;
        ++p;
      } else if (*p == '+') {
        ++p;
      }
      // Get digits of exponent, if any.
      for (expon = 0; *p >= '0' && *p <= '9'; ++p) {
        expon = expon * 10 + (*p - '0');
      }
234
235
236
      if (expon > 308) expon = 308;
      // Calculate scaling factor.
      while (expon >= 50) { scale *= 1E50; expon -= 50; }
Guolin Ke's avatar
Guolin Ke committed
237
      while (expon >= 8) { scale *= 1E8;  expon -= 8; }
238
      while (expon > 0) { scale *= 10.0; expon -= 1; }
Guolin Ke's avatar
Guolin Ke committed
239
    }
Guolin Ke's avatar
Guolin Ke committed
240
241
242
    // Return signed and scaled floating point result.
    *out = sign * (frac ? (value / scale) : (value * scale));
  } else {
243
    size_t cnt = 0;
244
    while (*(p + cnt) != '\0' && *(p + cnt) != ' '
245
246
247
           && *(p + cnt) != '\t' && *(p + cnt) != ','
           && *(p + cnt) != '\n' && *(p + cnt) != '\r'
           && *(p + cnt) != ':') {
248
249
      ++cnt;
    }
250
    if (cnt > 0) {
Guolin Ke's avatar
Guolin Ke committed
251
      std::string tmp_str(p, cnt);
Guolin Ke's avatar
Guolin Ke committed
252
      std::transform(tmp_str.begin(), tmp_str.end(), tmp_str.begin(), Common::tolower);
zhangjin's avatar
zhangjin committed
253
254
      if (tmp_str == std::string("na") || tmp_str == std::string("nan") ||
          tmp_str == std::string("null")) {
Guolin Ke's avatar
Guolin Ke committed
255
        *out = NAN;
256
      } else if (tmp_str == std::string("inf") || tmp_str == std::string("infinity")) {
257
        *out = sign * 1e308;
258
      } else {
259
        Log::Fatal("Unknown token %s in data file", tmp_str.c_str());
Guolin Ke's avatar
Guolin Ke committed
260
261
      }
      p += cnt;
262
    }
Guolin Ke's avatar
Guolin Ke committed
263
  }
Guolin Ke's avatar
Guolin Ke committed
264

Guolin Ke's avatar
Guolin Ke committed
265
266
267
  while (*p == ' ') {
    ++p;
  }
Guolin Ke's avatar
Guolin Ke committed
268

Guolin Ke's avatar
Guolin Ke committed
269
270
271
  return p;
}

272
inline static bool AtoiAndCheck(const char* p, int* out) {
273
274
275
276
277
278
279
  const char* after = Atoi(p, out);
  if (*after != '\0') {
    return false;
  }
  return true;
}

280
inline static bool AtofAndCheck(const char* p, double* out) {
281
282
283
284
285
286
287
  const char* after = Atof(p, out);
  if (*after != '\0') {
    return false;
  }
  return true;
}

288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
/*!
 * \brief Number of decimal digits needed to print n (0 counts as one digit).
 * \param n Value to measure.
 * \return A value in [1, 10].
 */
inline static unsigned CountDecimalDigit32(uint32_t n) {
#if defined(_MSC_VER) || defined(__GNUC__)
  // kPow10[k] is the smallest (k+1)-digit number; entry 0 is 0 so n == 0 maps to one digit.
  static const uint32_t kPow10[] = {
    0, 10, 100, 1000, 10000, 100000, 1000000, 10000000, 100000000, 1000000000
  };
#ifdef _MSC_VER
  unsigned long msb_index = 0;
  _BitScanReverse(&msb_index, n | 1);
  const uint32_t bit_length = static_cast<uint32_t>(msb_index + 1);
#else
  const uint32_t bit_length = 32 - __builtin_clz(n | 1);
#endif
  // 1233 / 4096 approximates log10(2); the estimate may be one too high, fixed via the table.
  const uint32_t estimate = bit_length * 1233 >> 12;
  return estimate + 1 - (n < kPow10[estimate] ? 1 : 0);
#else
  unsigned digits = 1;
  for (uint32_t bound = 10; digits < 10 && n >= bound; bound *= 10) {
    ++digits;
  }
  return digits;
#endif
}

/*!
 * \brief Write the decimal representation of value into buffer, null-terminated.
 * \param value Number to format.
 * \param buffer Destination; must hold CountDecimalDigit32(value) + 1 bytes (at most 11).
 */
inline static void Uint32ToStr(uint32_t value, char* buffer) {
  // Two-character digit pairs "00".."99". `static` so the table is initialised once
  // instead of being rebuilt on the stack on every call.
  static const char kDigitsLut[200] = {
    '0', '0', '0', '1', '0', '2', '0', '3', '0', '4', '0', '5', '0', '6', '0', '7', '0', '8', '0', '9',
    '1', '0', '1', '1', '1', '2', '1', '3', '1', '4', '1', '5', '1', '6', '1', '7', '1', '8', '1', '9',
    '2', '0', '2', '1', '2', '2', '2', '3', '2', '4', '2', '5', '2', '6', '2', '7', '2', '8', '2', '9',
    '3', '0', '3', '1', '3', '2', '3', '3', '3', '4', '3', '5', '3', '6', '3', '7', '3', '8', '3', '9',
    '4', '0', '4', '1', '4', '2', '4', '3', '4', '4', '4', '5', '4', '6', '4', '7', '4', '8', '4', '9',
    '5', '0', '5', '1', '5', '2', '5', '3', '5', '4', '5', '5', '5', '6', '5', '7', '5', '8', '5', '9',
    '6', '0', '6', '1', '6', '2', '6', '3', '6', '4', '6', '5', '6', '6', '6', '7', '6', '8', '6', '9',
    '7', '0', '7', '1', '7', '2', '7', '3', '7', '4', '7', '5', '7', '6', '7', '7', '7', '8', '7', '9',
    '8', '0', '8', '1', '8', '2', '8', '3', '8', '4', '8', '5', '8', '6', '8', '7', '8', '8', '8', '9',
    '9', '0', '9', '1', '9', '2', '9', '3', '9', '4', '9', '5', '9', '6', '9', '7', '9', '8', '9', '9'
  };
  unsigned digit = CountDecimalDigit32(value);
  buffer += digit;
  *buffer = '\0';

  // Emit two digits at a time, from the least significant end backwards.
  while (value >= 100) {
    const unsigned i = (value % 100) << 1;
    value /= 100;
    *--buffer = kDigitsLut[i + 1];
    *--buffer = kDigitsLut[i];
  }

  if (value < 10) {
    *--buffer = static_cast<char>(value) + '0';
  } else {
    const unsigned i = value << 1;
    *--buffer = kDigitsLut[i + 1];
    *--buffer = kDigitsLut[i];
  }
}

/*!
 * \brief Write the decimal representation of a signed 32-bit value, null-terminated.
 * \param value Number to format (INT32_MIN included).
 * \param buffer Destination; needs up to 12 bytes ('-' + 10 digits + '\0').
 */
inline static void Int32ToStr(int32_t value, char* buffer) {
  uint32_t magnitude = static_cast<uint32_t>(value);
  if (value < 0) {
    *buffer = '-';
    ++buffer;
    // Unsigned negation is well defined, so INT32_MIN is handled correctly.
    magnitude = 0u - magnitude;
  }
  Uint32ToStr(magnitude, buffer);
}

366
/*!
 * \brief Format a double with 17 significant digits (round-trippable) into buffer.
 * \param value Number to format.
 * \param buffer Destination buffer.
 * \param buffer_len Size of buffer in bytes. Output is always null-terminated and
 *        truncated to fit, instead of overflowing as unbounded sprintf could.
 */
inline static void DoubleToStr(double value, char* buffer, size_t buffer_len) {
#ifdef _MSC_VER
  sprintf_s(buffer, buffer_len, "%.17g", value);
#else
  std::snprintf(buffer, buffer_len, "%.17g", value);
#endif
}

Guolin Ke's avatar
Guolin Ke committed
378
379
380
381
382
383
384
385
386
387
388
389
390
391
/*!
 * \brief Advance past any run of ' ' and '\t' characters.
 * \param p Input cursor.
 * \return First character that is neither a space nor a tab.
 */
inline static const char* SkipSpaceAndTab(const char* p) {
  for (; *p == ' ' || *p == '\t'; ++p) {
  }
  return p;
}

/*!
 * \brief Advance past any run of '\n', '\r' and ' ' characters.
 * \param p Input cursor.
 * \return First character that is not a line break or space.
 */
inline static const char* SkipReturn(const char* p) {
  for (;;) {
    const char c = *p;
    if (c != '\n' && c != '\r' && c != ' ') {
      return p;
    }
    ++p;
  }
}

Guolin Ke's avatar
Guolin Ke committed
392
393
/*!
 * \brief Element-wise static_cast of a vector from T to T2.
 * \param arr Source values.
 * \return A new vector of the same length holding static_cast<T2>(arr[i]).
 */
template<typename T, typename T2>
inline static std::vector<T2> ArrayCast(const std::vector<T>& arr) {
  std::vector<T2> ret;
  ret.reserve(arr.size());
  std::transform(arr.begin(), arr.end(), std::back_inserter(ret),
                 [](const T& element) { return static_cast<T2>(element); });
  return ret;
}

401
402
/*!
 * \brief Fast single-value formatter used by ArrayToStringFast.
 *
 * Primary template: signed integers, formatted via Int32ToStr (assumes the
 * value fits in int32_t). Specializations handle floating point and unsigned types.
 */
template<typename T, bool is_float, bool is_unsign>
struct __TToStringHelperFast {
  void operator()(T value, char* buffer, size_t) const {
    Int32ToStr(value, buffer);
  }
};

/*! \brief Floating-point values: "%g" formatting, bounded by buf_len and always null-terminated. */
template<typename T>
struct __TToStringHelperFast<T, true, false> {
  void operator()(T value, char* buffer, size_t buf_len) const {
    // Cast to double: "%g" expects a double (passing long double would be UB).
#ifdef _MSC_VER
    sprintf_s(buffer, buf_len, "%g", static_cast<double>(value));
#else
    std::snprintf(buffer, buf_len, "%g", static_cast<double>(value));
#endif
  }
};

/*! \brief Unsigned integers, formatted via Uint32ToStr (assumes the value fits in uint32_t). */
template<typename T>
struct __TToStringHelperFast<T, false, true> {
  void operator()(T value, char* buffer, size_t) const {
    Uint32ToStr(value, buffer);
  }
};

431
template<typename T>
432
433
inline static std::string ArrayToStringFast(const std::vector<T>& arr, size_t n) {
  if (arr.empty() || n == 0) {
434
    return std::string("");
Guolin Ke's avatar
Guolin Ke committed
435
  }
436
437
438
  __TToStringHelperFast<T, std::is_floating_point<T>::value, std::is_unsigned<T>::value> helper;
  const size_t buf_len = 16;
  std::vector<char> buffer(buf_len);
439
  std::stringstream str_buf;
440
441
442
443
444
  helper(arr[0], buffer.data(), buf_len);
  str_buf << buffer.data();
  for (size_t i = 1; i < std::min(n, arr.size()); ++i) {
    helper(arr[i], buffer.data(), buf_len);
    str_buf << ' ' << buffer.data();
Guolin Ke's avatar
Guolin Ke committed
445
  }
446
  return str_buf.str();
Guolin Ke's avatar
Guolin Ke committed
447
448
}

449
inline static std::string ArrayToString(const std::vector<double>& arr, size_t n) {
Guolin Ke's avatar
Guolin Ke committed
450
451
452
  if (arr.empty() || n == 0) {
    return std::string("");
  }
453
454
  const size_t buf_len = 32;
  std::vector<char> buffer(buf_len);
Guolin Ke's avatar
Guolin Ke committed
455
  std::stringstream str_buf;
456
457
  DoubleToStr(arr[0], buffer.data(), buf_len);
  str_buf << buffer.data();
Guolin Ke's avatar
Guolin Ke committed
458
  for (size_t i = 1; i < std::min(n, arr.size()); ++i) {
459
460
    DoubleToStr(arr[i], buffer.data(), buf_len);
    str_buf << ' ' << buffer.data();
Guolin Ke's avatar
Guolin Ke committed
461
462
463
464
  }
  return str_buf.str();
}

465
466
467
/*!
 * \brief String-to-T converter used by StringToArray.
 *
 * Primary template: integral T, parsed with Atoi (0 if no digits).
 */
template<typename T, bool is_float>
struct __StringToTHelper {
  T operator()(const std::string& str) const {
    T parsed = 0;
    Atoi(str.c_str(), &parsed);
    return parsed;
  }
};

/*! \brief Floating-point T: parsed with std::stod (throws on invalid input) then cast to T. */
template<typename T>
struct __StringToTHelper<T, true> {
  T operator()(const std::string& str) const {
    const double parsed = std::stod(str);
    return static_cast<T>(parsed);
  }
};

Guolin Ke's avatar
Guolin Ke committed
481
template<typename T>
482
inline static std::vector<T> StringToArray(const std::string& str, char delimiter) {
Guolin Ke's avatar
Guolin Ke committed
483
  std::vector<std::string> strs = Split(str.c_str(), delimiter);
484
485
  std::vector<T> ret;
  ret.reserve(strs.size());
486
  __StringToTHelper<T, std::is_floating_point<T>::value> helper;
487
488
  for (const auto& s : strs) {
    ret.push_back(helper(s));
Guolin Ke's avatar
Guolin Ke committed
489
490
491
492
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
493
template<typename T>
494
495
496
497
498
499
inline static std::vector<T> StringToArray(const std::string& str, int n) {
  if (n == 0) {
    return std::vector<T>();
  }
  std::vector<std::string> strs = Split(str.c_str(), ' ');
  CHECK(strs.size() == static_cast<size_t>(n));
Guolin Ke's avatar
Guolin Ke committed
500
  std::vector<T> ret;
501
502
503
504
  ret.reserve(strs.size());
  __StringToTHelper<T, std::is_floating_point<T>::value> helper;
  for (const auto& s : strs) {
    ret.push_back(helper(s));
Guolin Ke's avatar
Guolin Ke committed
505
506
507
508
  }
  return ret;
}

509
510
511
512
513
514
515
516
517
518
519
520
/*!
 * \brief Cursor-based converter used by StringToArrayFast.
 *
 * Primary template: integral T via Atoi; returns the cursor after the token.
 */
template<typename T, bool is_float>
struct __StringToTHelperFast {
  const char* operator()(const char*p, T* out) const {
    return Atoi(p, out);
  }
};

/*! \brief Floating-point T: parsed with Atof as double, then cast to T. */
template<typename T>
struct __StringToTHelperFast<T, true> {
  const char* operator()(const char*p, T* out) const {
    double parsed = 0.0;
    const char* const next = Atof(p, &parsed);
    *out = static_cast<T>(parsed);
    return next;
  }
};

template<typename T>
inline static std::vector<T> StringToArrayFast(const std::string& str, int n) {
  if (n == 0) {
    return std::vector<T>();
  }
  auto p_str = str.c_str();
  __StringToTHelperFast<T, std::is_floating_point<T>::value> helper;
  std::vector<T> ret(n);
  for (int i = 0; i < n; ++i) {
    p_str = helper(p_str, &ret[i]);
  }
  return ret;
}

540
/*!
 * \brief Stream every element of strs into one string, separated by delimiter.
 * \param strs Values to join; floating-point values use 17 significant digits.
 * \param delimiter Separator inserted between elements.
 * \return The joined string, or "" when strs is empty.
 */
template<typename T>
inline static std::string Join(const std::vector<T>& strs, const char* delimiter) {
  if (strs.empty()) {
    return std::string("");
  }
  std::stringstream joined;
  joined << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  bool first = true;
  for (const auto& item : strs) {
    if (!first) {
      joined << delimiter;
    }
    joined << item;
    first = false;
  }
  return joined.str();
}

555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
/*!
 * \brief Join specialisation for int8_t: elements are printed as numbers, not characters.
 * \param strs Values to join.
 * \param delimiter Separator inserted between elements.
 * \return The joined string, or "" when strs is empty.
 */
template<>
inline std::string Join<int8_t>(const std::vector<int8_t>& strs, const char* delimiter) {
  if (strs.empty()) {
    return std::string("");
  }
  std::stringstream joined;
  joined << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  for (size_t i = 0; i < strs.size(); ++i) {
    if (i != 0) {
      joined << delimiter;
    }
    // Widen so operator<< prints the numeric value instead of a char.
    joined << static_cast<int16_t>(strs[i]);
  }
  return joined.str();
}

570
/*!
 * \brief Join strs[start, end) with delimiter.
 * \param strs Values to join; floating-point values use 17 significant digits.
 * \param start First index. If it lies past the end it is clamped to the last element.
 * \param end One past the last index; clamped to strs.size().
 * \param delimiter Separator inserted between elements.
 * \return The joined string; "" when strs is empty or start >= end.
 */
template<typename T>
inline static std::string Join(const std::vector<T>& strs, size_t start, size_t end, const char* delimiter) {
  // size_t arithmetic: `end - start <= 0` missed start > end (it wraps), and an
  // empty vector made `strs.size() - 1` wrap and read strs[start] out of bounds.
  if (start >= end || strs.empty()) {
    return std::string("");
  }
  start = std::min(start, static_cast<size_t>(strs.size()) - 1);
  end = std::min(end, static_cast<size_t>(strs.size()));
  std::stringstream str_buf;
  str_buf << std::setprecision(std::numeric_limits<double>::digits10 + 2);
  str_buf << strs[start];
  for (size_t i = start + 1; i < end; ++i) {
    str_buf << delimiter;
    str_buf << strs[i];
  }
  return str_buf.str();
}

587
/*!
 * \brief Smallest power of two that is >= x.
 * \param x Target value; any x <= 1 yields 1.
 * \return The power of two, or 0 if it would exceed 2^62 (the largest power of
 *         two representable in int64_t).
 */
inline static int64_t Pow2RoundUp(int64_t x) {
  const int64_t kMaxPow2 = static_cast<int64_t>(1) << 62;
  int64_t t = 1;
  while (t < x) {
    // Stop before shifting into the sign bit: the old 64-iteration loop
    // left-shifted a negative int64_t (UB) for x > 2^62.
    if (t == kMaxPow2) {
      return 0;
    }
    t <<= 1;
  }
  return t;
}

598
/*!
 * \brief Do inplace softmax transformation on p_rec
 * \param p_rec The input/output vector of the values; an empty vector is left unchanged.
 */
inline static void Softmax(std::vector<double>* p_rec) {
  std::vector<double> &rec = *p_rec;
  if (rec.empty()) {
    return;  // rec[0] below would be out of bounds
  }
  // Subtract the maximum before exponentiating so exp() cannot overflow.
  double wmax = rec[0];
  for (size_t i = 1; i < rec.size(); ++i) {
    wmax = std::max(rec[i], wmax);
  }
  double wsum = 0.0;
  for (size_t i = 0; i < rec.size(); ++i) {
    rec[i] = std::exp(rec[i] - wmax);
    wsum += rec[i];
  }
  for (size_t i = 0; i < rec.size(); ++i) {
    rec[i] /= static_cast<double>(wsum);
  }
}

618
/*!
 * \brief Softmax of input[0..len) written to output[0..len).
 * \param input Source values.
 * \param output Destination (may alias input).
 * \param len Number of elements; len <= 0 is a no-op.
 */
inline static void Softmax(const double* input, double* output, int len) {
  if (len <= 0) {
    return;  // input[0] below would be out of bounds
  }
  // Subtract the maximum before exponentiating so exp() cannot overflow.
  double wmax = input[0];
  for (int i = 1; i < len; ++i) {
    wmax = std::max(input[i], wmax);
  }
  double wsum = 0.0;
  for (int i = 0; i < len; ++i) {
    output[i] = std::exp(input[i] - wmax);
    wsum += output[i];
  }
  for (int i = 0; i < len; ++i) {
    output[i] /= static_cast<double>(wsum);
  }
}

Guolin Ke's avatar
Guolin Ke committed
633
634
635
/*!
 * \brief Borrow const raw pointers to the objects owned by a vector of unique_ptr.
 * \param input Owning vector; must outlive the returned pointers.
 * \return One non-owning pointer per element, in order (null entries stay null).
 */
template<typename T>
std::vector<const T*> ConstPtrInVectorWrapper(const std::vector<std::unique_ptr<T>>& input) {
  std::vector<const T*> ret;
  ret.reserve(input.size());
  for (const auto& owner : input) {
    ret.push_back(owner.get());
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
642
/*!
 * \brief Stable-sort keys[start..] and reorder values[start..] alongside them.
 * \param keys Sort keys; elements before `start` are untouched.
 * \param values Payloads, permuted in lock-step with keys (must be at least as long as keys).
 * \param start First index of the range to sort.
 * \param is_reverse Sort descending when true, ascending otherwise.
 */
template<typename T1, typename T2>
inline static void SortForPair(std::vector<T1>* keys, std::vector<T2>* values, size_t start, bool is_reverse = false) {
  auto& ref_key = *keys;
  auto& ref_value = *values;
  std::vector<std::pair<T1, T2>> arr;
  if (start < ref_key.size()) {
    arr.reserve(ref_key.size() - start);
  }
  for (size_t i = start; i < ref_key.size(); ++i) {
    arr.emplace_back(ref_key[i], ref_value[i]);
  }
  if (!is_reverse) {
    std::stable_sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
      return a.first < b.first;
    });
  } else {
    std::stable_sort(arr.begin(), arr.end(), [](const std::pair<T1, T2>& a, const std::pair<T1, T2>& b) {
      return a.first > b.first;
    });
  }
  // arr[j] holds the element that came from position start + j. The previous loop
  // wrote arr[i] to position i, which corrupted the output whenever start > 0.
  for (size_t j = 0; j < arr.size(); ++j) {
    ref_key[start + j] = arr[j].first;
    ref_value[start + j] = arr[j].second;
  }
}

665
/*!
 * \brief Collect the data() pointer of every inner vector.
 * \param data Outer vector; its rows must not be resized while the pointers are in use.
 * \return One mutable pointer per row, in order.
 */
template <typename T>
inline static std::vector<T*> Vector2Ptr(std::vector<std::vector<T>>* data) {
  std::vector<T*> ptr;
  ptr.reserve(data->size());
  for (auto& row : *data) {
    ptr.push_back(row.data());
  }
  return ptr;
}

/*!
 * \brief Sizes of every inner vector, as int.
 * \param data Outer vector.
 * \return data[i].size() cast to int, in order.
 */
template <typename T>
inline static std::vector<int> VectorSize(const std::vector<std::vector<T>>& data) {
  std::vector<int> ret;
  ret.reserve(data.size());
  for (const auto& row : data) {
    ret.push_back(static_cast<int>(row.size()));
  }
  return ret;
}

Guolin Ke's avatar
Guolin Ke committed
684
/*!
 * \brief Map NaN to 0 and clamp x into [-1e300, 1e300].
 * \param x Value to sanitise.
 * \return A finite value.
 */
inline static double AvoidInf(double x) {
  const double kBound = 1e300;
  if (std::isnan(x)) {
    return 0.0;
  }
  return std::min(std::max(x, -kBound), kBound);
}

696
/*!
 * \brief Map NaN to 0 and clamp x to roughly [-1e38, 1e38] (float variant).
 * \param x Value to sanitise.
 * \return A finite float.
 */
inline static float AvoidInf(float x) {
  if (std::isnan(x)) {
    return 0.0f;
  }
  if (x >= 1e38) {
    return 1e38f;
  }
  return (x <= -1e38) ? -1e38f : x;
}

/*!
 * \brief Type-deduction helper: returns a null pointer whose static type is value_type*.
 *
 * Only the return type matters; ParallelSort uses it to recover the element type.
 */
template<typename _Iter>
inline static typename std::iterator_traits<_Iter>::value_type* IteratorValType(_Iter) {
  return nullptr;
}

713
template<typename _RanIt, typename _Pr, typename _VTRanIt> inline
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) {
  size_t len = _Last - _First;
  const size_t kMinInnerLen = 1024;
  int num_threads = 1;
  #pragma omp parallel
  #pragma omp master
  {
    num_threads = omp_get_num_threads();
  }
  if (len <= kMinInnerLen || num_threads <= 1) {
    std::sort(_First, _Last, _Pred);
    return;
  }
  size_t inner_size = (len + num_threads - 1) / num_threads;
  inner_size = std::max(inner_size, kMinInnerLen);
  num_threads = static_cast<int>((len + inner_size - 1) / inner_size);
730
  #pragma omp parallel for schedule(static, 1)
731
732
733
734
735
736
737
738
739
740
741
  for (int i = 0; i < num_threads; ++i) {
    size_t left = inner_size*i;
    size_t right = left + inner_size;
    right = std::min(right, len);
    if (right > left) {
      std::sort(_First + left, _First + right, _Pred);
    }
  }
  // Buffer for merge.
  std::vector<_VTRanIt> temp_buf(len);
  _RanIt buf = temp_buf.begin();
742
  size_t s = inner_size;
743
744
745
  // Recursive merge
  while (s < len) {
    int loop_size = static_cast<int>((len + s * 2 - 1) / (s * 2));
746
    #pragma omp parallel for schedule(static, 1)
747
748
749
750
751
    for (int i = 0; i < loop_size; ++i) {
      size_t left = i * 2 * s;
      size_t mid = left + s;
      size_t right = mid + s;
      right = std::min(len, right);
Guolin Ke's avatar
Guolin Ke committed
752
      if (mid >= right) { continue; }
753
754
755
756
757
758
759
      std::copy(_First + left, _First + mid, buf + left);
      std::merge(buf + left, buf + mid, _First + mid, _First + right, _First + left, _Pred);
    }
    s *= 2;
  }
}

760
/*!
 * \brief Convenience overload: deduces the element type from the iterator.
 * \param _First Begin of the random-access range.
 * \param _Last End of the range.
 * \param _Pred Strict weak ordering.
 */
template<typename _RanIt, typename _Pr> inline
static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred) {
  ParallelSort(_First, _Last, _Pred, IteratorValType(_First));
}

765
// Check that all y[] are in interval [ymin, ymax] (end points included); throws error if not
766
template <typename T>
767
inline static void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, const char *callername) {
768
  auto fatal_msg = [&y, &ymin, &ymax, &callername](int i) {
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
    std::ostringstream os;
    os << "[%s]: does not tolerate element [#%i = " << y[i] << "] outside [" << ymin << ", " << ymax << "]";
    Log::Fatal(os.str().c_str(), callername, i);
  };
  for (int i = 1; i < ny; i += 2) {
    if (y[i - 1] < y[i]) {
      if (y[i - 1] < ymin) {
        fatal_msg(i - 1);
      } else if (y[i] > ymax) {
        fatal_msg(i);
      }
    } else {
      if (y[i - 1] > ymax) {
        fatal_msg(i - 1);
      } else if (y[i] < ymin) {
        fatal_msg(i);
      }
    }
  }
788
  if (ny & 1) {  // odd
789
790
    if (y[ny - 1] < ymin || y[ny - 1] > ymax) {
      fatal_msg(ny - 1);
791
792
793
794
795
796
    }
  }
}

// One-pass scan over array w with nw elements: find min, max and sum of elements;
// this is useful for checking weight requirements.
// Elements are compared in pairs (about 1.5 comparisons per element). The sum is
// accumulated in T2, so a wider T2 (e.g. double for float weights) keeps its precision.
// If nw <= 0 there is no min/max: *mi and *ma are left untouched and *su (if requested) is set to 0.
// Any of mi, ma, su may be nullptr to skip that output.
template <typename T1, typename T2>
inline static void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) {
  if (nw <= 0) {
    // The even-length seed below would read w[0] and w[1] out of bounds.
    if (su != nullptr) {
      *su = static_cast<T2>(0);
    }
    return;
  }
  T1 minw;
  T1 maxw;
  T2 sumw;
  int i;
  if (nw & 1) {  // odd: seed with w[0], then consume pairs (1,2), (3,4), ...
    minw = w[0];
    maxw = w[0];
    sumw = static_cast<T2>(w[0]);
    i = 2;
  } else {  // even: seed with pair (0,1), then consume pairs (2,3), ...
    if (w[0] < w[1]) {
      minw = w[0];
      maxw = w[1];
    } else {
      minw = w[1];
      maxw = w[0];
    }
    sumw = static_cast<T2>(w[0]) + static_cast<T2>(w[1]);
    i = 3;
  }
  for (; i < nw; i += 2) {
    if (w[i - 1] < w[i]) {
      minw = std::min(minw, w[i - 1]);
      maxw = std::max(maxw, w[i]);
    } else {
      minw = std::min(minw, w[i]);
      maxw = std::max(maxw, w[i - 1]);
    }
    sumw += static_cast<T2>(w[i - 1]) + static_cast<T2>(w[i]);
  }
  if (mi != nullptr) {
    *mi = minw;
  }
  if (ma != nullptr) {
    *ma = maxw;
  }
  if (su != nullptr) {
    *su = sumw;
  }
}

840
/*!
 * \brief Zero-filled bitset with room for n bits.
 * \param n Number of bits.
 * \return ceil(n / 32) zero words.
 */
inline static std::vector<uint32_t> EmptyBitset(int n) {
  const int full_words = n / 32;
  const int partial_word = (n % 32 != 0) ? 1 : 0;
  return std::vector<uint32_t>(full_words + partial_word);
}

/*!
 * \brief Set bit `val` in the bitset, growing it when needed.
 * \param vec Bitset words (32 bits each); resized with zero words if too short.
 * \param val Non-negative bit index.
 */
template<typename T>
inline static void InsertBitset(std::vector<uint32_t>* vec, const T val) {
  auto& ref_v = *vec;
  const int i1 = val / 32;
  const int i2 = val % 32;
  if (static_cast<int>(ref_v.size()) < i1 + 1) {
    ref_v.resize(i1 + 1, 0);
  }
  // Shift an unsigned one: `1 << 31` on a signed int overflows (UB in C++11).
  ref_v[i1] |= (static_cast<uint32_t>(1) << i2);
}

857
858
/*!
 * \brief Build a bitset with the bits listed in vals[0..n) set.
 * \param vals Non-negative bit indices.
 * \param n Number of indices.
 * \return Words sized to hold the largest index (empty if n <= 0).
 */
template<typename T>
inline static std::vector<uint32_t> ConstructBitset(const T* vals, int n) {
  std::vector<uint32_t> ret;
  for (int i = 0; i < n; ++i) {
    const int i1 = vals[i] / 32;
    const int i2 = vals[i] % 32;
    if (static_cast<int>(ret.size()) < i1 + 1) {
      ret.resize(i1 + 1, 0);
    }
    // Shift an unsigned one: `1 << 31` on a signed int overflows (UB in C++11).
    ret[i1] |= (static_cast<uint32_t>(1) << i2);
  }
  return ret;
}

871
872
/*!
 * \brief Test bit `pos` of a bitset of n 32-bit words.
 * \param bits Bitset words.
 * \param n Number of words in bits.
 * \param pos Non-negative bit index.
 * \return true iff the bit is set; indices beyond the bitset read as false.
 */
template<typename T>
inline static bool FindInBitset(const uint32_t* bits, int n, T pos) {
  const int word = pos / 32;
  if (word >= n) {
    return false;
  }
  const int bit = pos % 32;
  return ((bits[word] >> bit) & 1u) != 0;
}

881
882
883
884
885
886
887
888
889
/*!
 * \brief For ordered inputs, treat b as "equal" to a if b is at most one ULP above a.
 * \return b <= nextafter(a, +inf).
 */
inline static bool CheckDoubleEqualOrdered(double a, double b) {
  return b <= std::nextafter(a, INFINITY);
}

/*!
 * \brief The next representable double above a.
 */
inline static double GetDoubleUpperBound(double a) {
  const double upper = std::nextafter(a, INFINITY);
  return upper;
}

890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
/*!
 * \brief Length of the current line, i.e. characters before the first '\0', '\n' or '\r'.
 * \param str Null-terminated input.
 */
inline static size_t GetLine(const char* str) {
  size_t len = 0;
  while (str[len] != '\0' && str[len] != '\n' && str[len] != '\r') {
    ++len;
  }
  return len;
}

/*!
 * \brief Skip at most one line terminator: an optional '\r' followed by an optional '\n'.
 * \param str Input cursor.
 * \return Cursor after the terminator (unchanged if none).
 */
inline static const char* SkipNewLine(const char* str) {
  const char* cursor = str;
  if (*cursor == '\r') {
    ++cursor;
  }
  if (*cursor == '\n') {
    ++cursor;
  }
  return cursor;
}

908
909
910
911
912
/*!
 * \brief Sign of x: 1 if positive, -1 if negative, 0 otherwise (including NaN).
 */
template <typename T>
static int Sign(T x) {
  if (x > T(0)) {
    return 1;
  }
  if (x < T(0)) {
    return -1;
  }
  return 0;
}

Guolin Ke's avatar
Guolin Ke committed
913
914
915
916
917
918
919
920
921
/*!
 * \brief Natural log that returns -infinity instead of NaN/domain errors for x <= 0.
 */
template <typename T>
static T SafeLog(T x) {
  return (x > 0) ? std::log(x) : -INFINITY;
}

922
923
924
925
926
927
928
929
930
/*!
 * \brief True iff every byte of s is 7-bit ASCII (code <= 127).
 */
inline bool CheckASCII(const std::string& s) {
  return std::all_of(s.begin(), s.end(), [](char c) {
    return static_cast<unsigned char>(c) <= 127;
  });
}

931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
/*!
 * \brief True iff s contains none of the JSON structural characters " , : [ ] { }.
 */
inline bool CheckAllowedJSON(const std::string& s) {
  // Characters that would break the hand-written JSON emitted elsewhere.
  const char* const kForbidden = "\",:[]{}";
  return s.find_first_of(kForbidden) == std::string::npos;
}

Guolin Ke's avatar
Guolin Ke committed
949
950
951
952
}  // namespace Common

}  // namespace LightGBM

Guolin Ke's avatar
Guolin Ke committed
953
#endif   // LightGBM_UTILS_COMMON_FUN_H_