Unverified commit b52ecb16, authored by James Lamb, committed by GitHub
Browse files

simplify and speed up comparisons for splits with identical gains (#4542)

* fix incorrect behavior of SplitInfo == operator for splits with identical gains

* LightSplitInfo too, and improve comment

* don't check features unnecessarily

* update LightSplitInfo too
parent 75979bac
...@@ -134,6 +134,11 @@ struct SplitInfo { ...@@ -134,6 +134,11 @@ struct SplitInfo {
if (other_gain == NAN) { if (other_gain == NAN) {
other_gain = kMinScore; other_gain = kMinScore;
} }
if (local_gain != other_gain) {
return local_gain > other_gain;
}
// if gains are identical, choose the feature with the smaller index
int local_feature = this->feature; int local_feature = this->feature;
int other_feature = si.feature; int other_feature = si.feature;
// replace -1 with max int // replace -1 with max int
...@@ -144,14 +149,10 @@ struct SplitInfo { ...@@ -144,14 +149,10 @@ struct SplitInfo {
if (other_feature == -1) { if (other_feature == -1) {
other_feature = INT32_MAX; other_feature = INT32_MAX;
} }
if (local_gain != other_gain) { return local_feature < other_feature;
return local_gain > other_gain;
} else {
// if same gain, use smaller feature
return local_feature < other_feature;
}
} }
/*! \brief test if a candidate SplitInfo is equivalent to this one */
inline bool operator == (const SplitInfo& si) const { inline bool operator == (const SplitInfo& si) const {
double local_gain = this->gain; double local_gain = this->gain;
double other_gain = si.gain; double other_gain = si.gain;
...@@ -163,6 +164,11 @@ struct SplitInfo { ...@@ -163,6 +164,11 @@ struct SplitInfo {
if (other_gain == NAN) { if (other_gain == NAN) {
other_gain = kMinScore; other_gain = kMinScore;
} }
if (local_gain != other_gain) {
return false;
}
// if same gain, splits are only equal if they also use the same feature
int local_feature = this->feature; int local_feature = this->feature;
int other_feature = si.feature; int other_feature = si.feature;
// replace -1 with max int // replace -1 with max int
...@@ -173,12 +179,7 @@ struct SplitInfo { ...@@ -173,12 +179,7 @@ struct SplitInfo {
if (other_feature == -1) { if (other_feature == -1) {
other_feature = INT32_MAX; other_feature = INT32_MAX;
} }
if (local_gain != other_gain) { return local_feature == other_feature;
return local_gain == other_gain;
} else {
// if same gain, use smaller feature
return local_feature == other_feature;
}
} }
}; };
...@@ -228,6 +229,11 @@ struct LightSplitInfo { ...@@ -228,6 +229,11 @@ struct LightSplitInfo {
if (other_gain == NAN) { if (other_gain == NAN) {
other_gain = kMinScore; other_gain = kMinScore;
} }
if (local_gain != other_gain) {
return local_gain > other_gain;
}
// if gains are identical, choose the feature with the smaller index
int local_feature = this->feature; int local_feature = this->feature;
int other_feature = si.feature; int other_feature = si.feature;
// replace -1 with max int // replace -1 with max int
...@@ -238,14 +244,10 @@ struct LightSplitInfo { ...@@ -238,14 +244,10 @@ struct LightSplitInfo {
if (other_feature == -1) { if (other_feature == -1) {
other_feature = INT32_MAX; other_feature = INT32_MAX;
} }
if (local_gain != other_gain) { return local_feature < other_feature;
return local_gain > other_gain;
} else {
// if same gain, use smaller feature
return local_feature < other_feature;
}
} }
/*! \brief test if a candidate LightSplitInfo is equivalent to this one */
inline bool operator == (const LightSplitInfo& si) const { inline bool operator == (const LightSplitInfo& si) const {
double local_gain = this->gain; double local_gain = this->gain;
double other_gain = si.gain; double other_gain = si.gain;
...@@ -257,6 +259,11 @@ struct LightSplitInfo { ...@@ -257,6 +259,11 @@ struct LightSplitInfo {
if (other_gain == NAN) { if (other_gain == NAN) {
other_gain = kMinScore; other_gain = kMinScore;
} }
if (local_gain != other_gain) {
return false;
}
// if same gain, splits are only equal if they also use the same feature
int local_feature = this->feature; int local_feature = this->feature;
int other_feature = si.feature; int other_feature = si.feature;
// replace -1 with max int // replace -1 with max int
...@@ -267,12 +274,7 @@ struct LightSplitInfo { ...@@ -267,12 +274,7 @@ struct LightSplitInfo {
if (other_feature == -1) { if (other_feature == -1) {
other_feature = INT32_MAX; other_feature = INT32_MAX;
} }
if (local_gain != other_gain) { return local_feature == other_feature;
return local_gain == other_gain;
} else {
// if same gain, use smaller feature
return local_feature == other_feature;
}
} }
}; };
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment