Commit cc98351f authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

speed up the AUC metric by parallel sort. (#556)

* speed up the AUC metric by parallel sort.

* fix unix compile.

* fix template in unix.

* remove warning.
parent 7fd0c7ed
......@@ -2,6 +2,7 @@
#define LIGHTGBM_UTILS_COMMON_FUN_H_
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
#include <cstdio>
#include <string>
......@@ -12,6 +13,7 @@
#include <cmath>
#include <functional>
#include <memory>
#include <iterator>
#include <type_traits>
#include <iomanip>
......@@ -472,6 +474,62 @@ inline static double AvoidInf(double x) {
}
}
template<class _Iter> inline
static typename std::iterator_traits<_Iter>::value_type* IteratorValType(_Iter) {
return (0);
}
template<class _RanIt, class _Pr, class _VTRanIt> inline
static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred, _VTRanIt*) {
size_t len = _Last - _First;
const size_t kMinInnerLen = 1024;
int num_threads = 1;
#pragma omp parallel
#pragma omp master
{
num_threads = omp_get_num_threads();
}
if (len <= kMinInnerLen || num_threads <= 1) {
std::sort(_First, _Last, _Pred);
return;
}
size_t inner_size = (len + num_threads - 1) / num_threads;
inner_size = std::max(inner_size, kMinInnerLen);
num_threads = static_cast<int>((len + inner_size - 1) / inner_size);
#pragma omp parallel for schedule(static,1)
for (int i = 0; i < num_threads; ++i) {
size_t left = inner_size*i;
size_t right = left + inner_size;
right = std::min(right, len);
if (right > left) {
std::sort(_First + left, _First + right, _Pred);
}
}
// Buffer for merge.
std::vector<_VTRanIt> temp_buf(len);
_RanIt buf = temp_buf.begin();
int s = inner_size;
// Recursive merge
while (s < len) {
int loop_size = static_cast<int>((len + s * 2 - 1) / (s * 2));
#pragma omp parallel for schedule(static,1)
for (int i = 0; i < loop_size; ++i) {
size_t left = i * 2 * s;
size_t mid = left + s;
size_t right = mid + s;
right = std::min(len, right);
std::copy(_First + left, _First + mid, buf + left);
std::merge(buf + left, buf + mid, _First + mid, _First + right, _First + left, _Pred);
}
s *= 2;
}
}
template<class _RanIt, class _Pr> inline
static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred) {
return ParallelSort(_First, _Last, _Pred, IteratorValType(_First));
}
} // namespace Common
} // namespace LightGBM
......
......@@ -2,6 +2,7 @@
#define LIGHTGBM_METRIC_BINARY_METRIC_HPP_
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/metric.h>
......@@ -195,7 +196,7 @@ public:
for (data_size_t i = 0; i < num_data_; ++i) {
sorted_idx.emplace_back(i);
}
std::sort(sorted_idx.begin(), sorted_idx.end(), [score](data_size_t a, data_size_t b) {return score[a] > score[b]; });
Common::ParallelSort(sorted_idx.begin(), sorted_idx.end(), [score](data_size_t a, data_size_t b) {return score[a] > score[b]; });
// temp sum of postive label
double cur_pos = 0.0f;
// total sum of postive label
......
......@@ -3,7 +3,7 @@
namespace LightGBM {
Linkers::Linkers(NetworkConfig config) {
Linkers::Linkers(NetworkConfig) {
int argc = 0;
char**argv = nullptr;
int flag = 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment