"src/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "45c53f7834106568be2af45e493b40e8095d470e"
Commit 302f84bc authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

trivial refine CheckElementsIntervalClosed & ObtainMinMaxSum (#1049)

* refine common.h

* fix typo

* specify captured variables
parent bd5e5e3e
......@@ -596,11 +596,29 @@ static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred) {
// Check that all y[] are in interval [ymin, ymax] (end points included); throws error if not
template <typename T>
inline void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, const char *callername) {
for (int i = 0; i < ny; ++i) {
if (y[i] < ymin || y[i] > ymax) {
std::ostringstream os;
os << "[%s]: does not tolerate element [#%i = " << y[i] << "] outside [" << ymin << ", " << ymax << "]";
Log::Fatal(os.str().c_str(), callername, i);
auto fatal_msg = [&y, &ymin, &ymax, &callername](int i) {
std::ostringstream os;
os << "[%s]: does not tolerate element [#%i = " << y[i] << "] outside [" << ymin << ", " << ymax << "]";
Log::Fatal(os.str().c_str(), callername, i);
};
for (int i = 1; i < ny; i += 2) {
if (y[i - 1] < y[i]) {
if (y[i - 1] < ymin) {
fatal_msg(i - 1);
} else if (y[i] > ymax) {
fatal_msg(i);
}
} else {
if (y[i - 1] > ymax) {
fatal_msg(i - 1);
} else if (y[i] < ymin) {
fatal_msg(i);
}
}
}
if (ny & 1) { // odd
if (y[ny - 1] < ymin || y[ny - 1] > ymax) {
fatal_msg(ny - 1);
}
}
}
......@@ -609,17 +627,45 @@ inline void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, cons
// this is useful for checking weight requirements.
template <typename T1, typename T2>
inline void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) {
T1 minw = w[0];
T1 maxw = w[0];
T2 sumw = static_cast<T2>(w[0]);
for (int i = 1; i < nw; ++i) {
sumw += w[i];
if (w[i] < minw) minw = w[i];
if (w[i] > maxw) maxw = w[i];
}
if (mi != nullptr) *mi = minw;
if (ma != nullptr) *ma = maxw;
if (su != nullptr) *su = sumw;
T1 minw;
T1 maxw;
T1 sumw;
int i;
if (nw & 1) { // odd
minw = w[0];
maxw = w[0];
sumw = w[0];
i = 2;
} else { // even
if (w[0] < w[1]) {
minw = w[0];
maxw = w[1];
} else {
minw = w[1];
maxw = w[0];
}
sumw = w[0] + w[1];
i = 3;
}
for (; i < nw; i += 2) {
if (w[i - 1] < w[i]) {
minw = std::min(minw, w[i - 1]);
maxw = std::max(maxw, w[i]);
} else {
minw = std::min(minw, w[i]);
maxw = std::max(maxw, w[i - 1]);
}
sumw += w[i - 1] + w[i];
}
if (mi != nullptr) {
*mi = minw;
}
if (ma != nullptr) {
*ma = maxw;
}
if (su != nullptr) {
*su = static_cast<T2>(sumw);
}
}
template<class T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment