Unverified commit b52ecb16, authored by James Lamb, committed by GitHub
Browse files

simplify and speed up comparisons for splits with identical gains (#4542)

* fix incorrect behavior of SplitInfo == operator for splits with identical gains

* LightSplitInfo too, and improve comment

* don't check features unnecessarily

* update LightSplitInfo too
parent 75979bac
...@@ -134,6 +134,11 @@ struct SplitInfo { ...@@ -134,6 +134,11 @@ struct SplitInfo {
if (other_gain == NAN) { if (other_gain == NAN) {
other_gain = kMinScore; other_gain = kMinScore;
} }
if (local_gain != other_gain) {
return local_gain > other_gain;
}
// if gains are identical, choose the feature with the smaller index
int local_feature = this->feature; int local_feature = this->feature;
int other_feature = si.feature; int other_feature = si.feature;
// replace -1 with max int // replace -1 with max int
...@@ -144,14 +149,10 @@ struct SplitInfo { ...@@ -144,14 +149,10 @@ struct SplitInfo {
if (other_feature == -1) { if (other_feature == -1) {
other_feature = INT32_MAX; other_feature = INT32_MAX;
} }
if (local_gain != other_gain) {
return local_gain > other_gain;
} else {
// if same gain, use smaller feature
return local_feature < other_feature; return local_feature < other_feature;
} }
}
/*! \brief test if a candidate SplitInfo is equivalent to this one */
inline bool operator == (const SplitInfo& si) const { inline bool operator == (const SplitInfo& si) const {
double local_gain = this->gain; double local_gain = this->gain;
double other_gain = si.gain; double other_gain = si.gain;
...@@ -163,6 +164,11 @@ struct SplitInfo { ...@@ -163,6 +164,11 @@ struct SplitInfo {
if (other_gain == NAN) { if (other_gain == NAN) {
other_gain = kMinScore; other_gain = kMinScore;
} }
if (local_gain != other_gain) {
return false;
}
// if same gain, splits are only equal if they also use the same feature
int local_feature = this->feature; int local_feature = this->feature;
int other_feature = si.feature; int other_feature = si.feature;
// replace -1 with max int // replace -1 with max int
...@@ -173,13 +179,8 @@ struct SplitInfo { ...@@ -173,13 +179,8 @@ struct SplitInfo {
if (other_feature == -1) { if (other_feature == -1) {
other_feature = INT32_MAX; other_feature = INT32_MAX;
} }
if (local_gain != other_gain) {
return local_gain == other_gain;
} else {
// if same gain, use smaller feature
return local_feature == other_feature; return local_feature == other_feature;
} }
}
}; };
struct LightSplitInfo { struct LightSplitInfo {
...@@ -228,6 +229,11 @@ struct LightSplitInfo { ...@@ -228,6 +229,11 @@ struct LightSplitInfo {
if (other_gain == NAN) { if (other_gain == NAN) {
other_gain = kMinScore; other_gain = kMinScore;
} }
if (local_gain != other_gain) {
return local_gain > other_gain;
}
// if gains are identical, choose the feature with the smaller index
int local_feature = this->feature; int local_feature = this->feature;
int other_feature = si.feature; int other_feature = si.feature;
// replace -1 with max int // replace -1 with max int
...@@ -238,14 +244,10 @@ struct LightSplitInfo { ...@@ -238,14 +244,10 @@ struct LightSplitInfo {
if (other_feature == -1) { if (other_feature == -1) {
other_feature = INT32_MAX; other_feature = INT32_MAX;
} }
if (local_gain != other_gain) {
return local_gain > other_gain;
} else {
// if same gain, use smaller feature
return local_feature < other_feature; return local_feature < other_feature;
} }
}
/*! \brief test if a candidate LightSplitInfo is equivalent to this one */
inline bool operator == (const LightSplitInfo& si) const { inline bool operator == (const LightSplitInfo& si) const {
double local_gain = this->gain; double local_gain = this->gain;
double other_gain = si.gain; double other_gain = si.gain;
...@@ -257,6 +259,11 @@ struct LightSplitInfo { ...@@ -257,6 +259,11 @@ struct LightSplitInfo {
if (other_gain == NAN) { if (other_gain == NAN) {
other_gain = kMinScore; other_gain = kMinScore;
} }
if (local_gain != other_gain) {
return false;
}
// if same gain, splits are only equal if they also use the same feature
int local_feature = this->feature; int local_feature = this->feature;
int other_feature = si.feature; int other_feature = si.feature;
// replace -1 with max int // replace -1 with max int
...@@ -267,13 +274,8 @@ struct LightSplitInfo { ...@@ -267,13 +274,8 @@ struct LightSplitInfo {
if (other_feature == -1) { if (other_feature == -1) {
other_feature = INT32_MAX; other_feature = INT32_MAX;
} }
if (local_gain != other_gain) {
return local_gain == other_gain;
} else {
// if same gain, use smaller feature
return local_feature == other_feature; return local_feature == other_feature;
} }
}
}; };
} // namespace LightGBM } // namespace LightGBM
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment