Unverified Commit 9558417a authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

fix nan in tree model (#2303)

* fix nan in tree model

* fix
parent 578a8c8a
...@@ -422,7 +422,7 @@ inline void Tree::Split(int leaf, int feature, int real_feature, ...@@ -422,7 +422,7 @@ inline void Tree::Split(int leaf, int feature, int real_feature,
split_feature_inner_[new_node_idx] = feature; split_feature_inner_[new_node_idx] = feature;
split_feature_[new_node_idx] = real_feature; split_feature_[new_node_idx] = real_feature;
split_gain_[new_node_idx] = Common::AvoidInf(gain); split_gain_[new_node_idx] = gain;
// add two new leaves // add two new leaves
left_child_[new_node_idx] = ~leaf; left_child_[new_node_idx] = ~leaf;
right_child_[new_node_idx] = ~num_leaves_; right_child_[new_node_idx] = ~num_leaves_;
......
...@@ -663,7 +663,9 @@ inline static std::vector<int> VectorSize(const std::vector<std::vector<T>>& dat ...@@ -663,7 +663,9 @@ inline static std::vector<int> VectorSize(const std::vector<std::vector<T>>& dat
} }
inline static double AvoidInf(double x) { inline static double AvoidInf(double x) {
if (x >= 1e300) { if (std::isnan(x)) {
return 0.0;
} else if (x >= 1e300) {
return 1e300; return 1e300;
} else if (x <= -1e300) { } else if (x <= -1e300) {
return -1e300; return -1e300;
...@@ -673,7 +675,9 @@ inline static double AvoidInf(double x) { ...@@ -673,7 +675,9 @@ inline static double AvoidInf(double x) {
} }
inline static float AvoidInf(float x) { inline static float AvoidInf(float x) {
if (x >= 1e38) { if (std::isnan(x)){
return 0.0f;
} else if (x >= 1e38) {
return 1e38f; return 1e38f;
} else if (x <= -1e38) { } else if (x <= -1e38) {
return -1e38f; return -1e38f;
......
...@@ -64,7 +64,7 @@ int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin, ...@@ -64,7 +64,7 @@ int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
SetMissingType(&decision_type_[new_node_idx], 2); SetMissingType(&decision_type_[new_node_idx], 2);
} }
threshold_in_bin_[new_node_idx] = threshold_bin; threshold_in_bin_[new_node_idx] = threshold_bin;
threshold_[new_node_idx] = Common::AvoidInf(threshold_double); threshold_[new_node_idx] = threshold_double;
++num_leaves_; ++num_leaves_;
return num_leaves_ - 1; return num_leaves_ - 1;
} }
...@@ -268,7 +268,7 @@ std::string Tree::NodeToJSON(int index) const { ...@@ -268,7 +268,7 @@ std::string Tree::NodeToJSON(int index) const {
str_buf << "{" << '\n'; str_buf << "{" << '\n';
str_buf << "\"split_index\":" << index << "," << '\n'; str_buf << "\"split_index\":" << index << "," << '\n';
str_buf << "\"split_feature\":" << split_feature_[index] << "," << '\n'; str_buf << "\"split_feature\":" << split_feature_[index] << "," << '\n';
str_buf << "\"split_gain\":" << split_gain_[index] << "," << '\n'; str_buf << "\"split_gain\":" << Common::AvoidInf(split_gain_[index]) << "," << '\n';
if (GetDecisionType(decision_type_[index], kCategoricalMask)) { if (GetDecisionType(decision_type_[index], kCategoricalMask)) {
int cat_idx = static_cast<int>(threshold_[index]); int cat_idx = static_cast<int>(threshold_[index]);
std::vector<int> cats; std::vector<int> cats;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment