Commit 87fa8b54 authored by Tony-Y's avatar Tony-Y Committed by Guolin Ke
Browse files

Change functions in common.h into template functions (#969) (#973)

* Fix coding style (#969)

Function names must be in the "Pascal Case" style.

* check_elements_interval_closed to CheckElementsIntervalClosed

* obtain_min_max_sum to ObtainMinMaxSum

* Change functions in common.h into template functions (#969)

* CheckElementsIntervalClosed

* ObtainMinMaxSum

These two functions were changed into template functions.

* Remove an unpreferable overload

* remove an overload of the function ObtainMinMaxSum

* Use stringstream to format T type
parent 6d34fb86
...@@ -580,20 +580,24 @@ static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred) { ...@@ -580,20 +580,24 @@ static void ParallelSort(_RanIt _First, _RanIt _Last, _Pr _Pred) {
} }
// Check that all y[] are in interval [ymin, ymax] (end points included); throws error if not // Check that all y[] are in interval [ymin, ymax] (end points included); throws error if not
inline void CheckElementsIntervalClosed(const float *y, float ymin, float ymax, int ny, const char *callername) { template <typename T>
inline void CheckElementsIntervalClosed(const T *y, T ymin, T ymax, int ny, const char *callername) {
for (int i = 0; i < ny; ++i) { for (int i = 0; i < ny; ++i) {
if (y[i] < ymin || y[i] > ymax) { if (y[i] < ymin || y[i] > ymax) {
Log::Fatal("[%s]: does not tolerate element [#%i = %f] outside [%f, %f]", callername, i, y[i], ymin, ymax); std::ostringstream os;
os << "[%s]: does not tolerate element [#%i = " << y[i] << "] outside [" << ymin << ", " << ymax << "]";
Log::Fatal(os.str().c_str(), callername, i);
} }
} }
} }
// One-pass scan over array w with nw elements: find min, max and sum of elements; // One-pass scan over array w with nw elements: find min, max and sum of elements;
// this is useful for checking weight requirements. // this is useful for checking weight requirements.
inline void ObtainMinMaxSum(const float *w, int nw, float *mi, float *ma, double *su) { template <typename T1, typename T2>
float minw = w[0]; inline void ObtainMinMaxSum(const T1 *w, int nw, T1 *mi, T1 *ma, T2 *su) {
float maxw = w[0]; T1 minw = w[0];
double sumw = static_cast<double>(w[0]); T1 maxw = w[0];
T2 sumw = static_cast<T2>(w[0]);
for (int i = 1; i < nw; ++i) { for (int i = 1; i < nw; ++i) {
sumw += w[i]; sumw += w[i];
if (w[i] < minw) minw = w[i]; if (w[i] < minw) minw = w[i];
......
...@@ -87,7 +87,7 @@ public: ...@@ -87,7 +87,7 @@ public:
sum_weights_ = static_cast<double>(num_data_); sum_weights_ = static_cast<double>(num_data_);
} else { } else {
float minw; float minw;
Common::ObtainMinMaxSum(weights_, num_data_, &minw, nullptr, &sum_weights_); Common::ObtainMinMaxSum(weights_, num_data_, &minw, (float*)nullptr, &sum_weights_);
if (minw < 0.0f) { if (minw < 0.0f) {
Log::Fatal("[%s:%s]: (metric) weights not allowed to be negative", GetName()[0].c_str(), __func__); Log::Fatal("[%s:%s]: (metric) weights not allowed to be negative", GetName()[0].c_str(), __func__);
} }
...@@ -178,7 +178,7 @@ public: ...@@ -178,7 +178,7 @@ public:
// check all weights are strictly positive; throw error if not // check all weights are strictly positive; throw error if not
if (weights_ != nullptr) { if (weights_ != nullptr) {
float minw; float minw;
Common::ObtainMinMaxSum(weights_, num_data_, &minw, nullptr, nullptr); Common::ObtainMinMaxSum(weights_, num_data_, &minw, (float*)nullptr, (float*)nullptr);
if (minw <= 0.0f) { if (minw <= 0.0f) {
Log::Fatal("[%s:%s]: (metric) all weights must be positive", GetName()[0].c_str(), __func__); Log::Fatal("[%s:%s]: (metric) all weights must be positive", GetName()[0].c_str(), __func__);
} }
...@@ -263,7 +263,7 @@ public: ...@@ -263,7 +263,7 @@ public:
sum_weights_ = static_cast<double>(num_data_); sum_weights_ = static_cast<double>(num_data_);
} else { } else {
float minw; float minw;
Common::ObtainMinMaxSum(weights_, num_data_, &minw, nullptr, &sum_weights_); Common::ObtainMinMaxSum(weights_, num_data_, &minw, (float*)nullptr, &sum_weights_);
if (minw < 0.0f) { if (minw < 0.0f) {
Log::Fatal("[%s:%s]: (metric) at least one weight is negative", GetName()[0].c_str(), __func__); Log::Fatal("[%s:%s]: (metric) at least one weight is negative", GetName()[0].c_str(), __func__);
} }
......
...@@ -318,7 +318,7 @@ public: ...@@ -318,7 +318,7 @@ public:
// Safety check of labels // Safety check of labels
float miny; float miny;
double sumy; double sumy;
Common::ObtainMinMaxSum(label_, num_data_, &miny, nullptr, &sumy); Common::ObtainMinMaxSum(label_, num_data_, &miny, (float*)nullptr, &sumy);
if (miny < 0.0f) { if (miny < 0.0f) {
Log::Fatal("[%s]: at least one target label is negative.", GetName()); Log::Fatal("[%s]: at least one target label is negative.", GetName());
} }
......
...@@ -58,7 +58,7 @@ public: ...@@ -58,7 +58,7 @@ public:
if (weights_ != nullptr) { if (weights_ != nullptr) {
float minw; float minw;
double sumw; double sumw;
Common::ObtainMinMaxSum(weights_, num_data_, &minw, nullptr, &sumw); Common::ObtainMinMaxSum(weights_, num_data_, &minw, (float*)nullptr, &sumw);
if (minw < 0.0f) { if (minw < 0.0f) {
Log::Fatal("[%s]: at least one weight is negative.", GetName()); Log::Fatal("[%s]: at least one weight is negative.", GetName());
} }
...@@ -163,7 +163,7 @@ public: ...@@ -163,7 +163,7 @@ public:
if (weights_ != nullptr) { if (weights_ != nullptr) {
Common::ObtainMinMaxSum(weights_, num_data_, &min_weight_, &max_weight_, nullptr); Common::ObtainMinMaxSum(weights_, num_data_, &min_weight_, &max_weight_, (float*)nullptr);
if (min_weight_ <= 0.0f) { if (min_weight_ <= 0.0f) {
Log::Fatal("[%s]: at least one weight is non-positive.", GetName()); Log::Fatal("[%s]: at least one weight is non-positive.", GetName());
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment