Unverified commit 51f37e9b, authored by Alberto Ferreira and committed by GitHub
Browse files

Cleanup MissingType enum constants (#2931)



* [refactor] Cleanup MissingType enum constants

* Update tree.cpp
Co-authored-by: Alberto Ferreira <alberto.ferreira@feedzai.com>
parent 2d4f3909
...@@ -257,13 +257,11 @@ class Tree { ...@@ -257,13 +257,11 @@ class Tree {
inline int NumericalDecision(double fval, int node) const { inline int NumericalDecision(double fval, int node) const {
uint8_t missing_type = GetMissingType(decision_type_[node]); uint8_t missing_type = GetMissingType(decision_type_[node]);
if (std::isnan(fval)) { if (std::isnan(fval) && missing_type != MissingType::NaN) {
if (missing_type != 2) { fval = 0.0f;
fval = 0.0f;
}
} }
if ((missing_type == 1 && IsZero(fval)) if ((missing_type == MissingType::Zero && IsZero(fval))
|| (missing_type == 2 && std::isnan(fval))) { || (missing_type == MissingType::NaN && std::isnan(fval))) {
if (GetDecisionType(decision_type_[node], kDefaultLeftMask)) { if (GetDecisionType(decision_type_[node], kDefaultLeftMask)) {
return left_child_[node]; return left_child_[node];
} else { } else {
...@@ -279,8 +277,8 @@ class Tree { ...@@ -279,8 +277,8 @@ class Tree {
inline int NumericalDecisionInner(uint32_t fval, int node, uint32_t default_bin, uint32_t max_bin) const { inline int NumericalDecisionInner(uint32_t fval, int node, uint32_t default_bin, uint32_t max_bin) const {
uint8_t missing_type = GetMissingType(decision_type_[node]); uint8_t missing_type = GetMissingType(decision_type_[node]);
if ((missing_type == 1 && fval == default_bin) if ((missing_type == MissingType::Zero && fval == default_bin)
|| (missing_type == 2 && fval == max_bin)) { || (missing_type == MissingType::NaN && fval == max_bin)) {
if (GetDecisionType(decision_type_[node], kDefaultLeftMask)) { if (GetDecisionType(decision_type_[node], kDefaultLeftMask)) {
return left_child_[node]; return left_child_[node];
} else { } else {
...@@ -301,7 +299,7 @@ class Tree { ...@@ -301,7 +299,7 @@ class Tree {
return right_child_[node];; return right_child_[node];;
} else if (std::isnan(fval)) { } else if (std::isnan(fval)) {
// NaN is always in the right // NaN is always in the right
if (missing_type == 2) { if (missing_type == MissingType::NaN) {
return right_child_[node]; return right_child_[node];
} }
int_fval = 0; int_fval = 0;
......
...@@ -57,13 +57,7 @@ int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin, ...@@ -57,13 +57,7 @@ int Tree::Split(int leaf, int feature, int real_feature, uint32_t threshold_bin,
decision_type_[new_node_idx] = 0; decision_type_[new_node_idx] = 0;
SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask); SetDecisionType(&decision_type_[new_node_idx], false, kCategoricalMask);
SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask); SetDecisionType(&decision_type_[new_node_idx], default_left, kDefaultLeftMask);
if (missing_type == MissingType::None) { SetMissingType(&decision_type_[new_node_idx], missing_type);
SetMissingType(&decision_type_[new_node_idx], 0);
} else if (missing_type == MissingType::Zero) {
SetMissingType(&decision_type_[new_node_idx], 1);
} else if (missing_type == MissingType::NaN) {
SetMissingType(&decision_type_[new_node_idx], 2);
}
threshold_in_bin_[new_node_idx] = threshold_bin; threshold_in_bin_[new_node_idx] = threshold_bin;
threshold_[new_node_idx] = threshold_double; threshold_[new_node_idx] = threshold_double;
++num_leaves_; ++num_leaves_;
...@@ -77,13 +71,7 @@ int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32 ...@@ -77,13 +71,7 @@ int Tree::SplitCategorical(int leaf, int feature, int real_feature, const uint32
int new_node_idx = num_leaves_ - 1; int new_node_idx = num_leaves_ - 1;
decision_type_[new_node_idx] = 0; decision_type_[new_node_idx] = 0;
SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask); SetDecisionType(&decision_type_[new_node_idx], true, kCategoricalMask);
if (missing_type == MissingType::None) { SetMissingType(&decision_type_[new_node_idx], missing_type);
SetMissingType(&decision_type_[new_node_idx], 0);
} else if (missing_type == MissingType::Zero) {
SetMissingType(&decision_type_[new_node_idx], 1);
} else if (missing_type == MissingType::NaN) {
SetMissingType(&decision_type_[new_node_idx], 2);
}
threshold_in_bin_[new_node_idx] = num_cat_; threshold_in_bin_[new_node_idx] = num_cat_;
threshold_[new_node_idx] = num_cat_; threshold_[new_node_idx] = num_cat_;
++num_cat_; ++num_cat_;
...@@ -316,9 +304,9 @@ std::string Tree::NodeToJSON(int index) const { ...@@ -316,9 +304,9 @@ std::string Tree::NodeToJSON(int index) const {
str_buf << "\"default_left\":false," << '\n'; str_buf << "\"default_left\":false," << '\n';
} }
uint8_t missing_type = GetMissingType(decision_type_[index]); uint8_t missing_type = GetMissingType(decision_type_[index]);
if (missing_type == 0) { if (missing_type == MissingType::None) {
str_buf << "\"missing_type\":\"None\"," << '\n'; str_buf << "\"missing_type\":\"None\"," << '\n';
} else if (missing_type == 1) { } else if (missing_type == MissingType::Zero) {
str_buf << "\"missing_type\":\"Zero\"," << '\n'; str_buf << "\"missing_type\":\"Zero\"," << '\n';
} else { } else {
str_buf << "\"missing_type\":\"NaN\"," << '\n'; str_buf << "\"missing_type\":\"NaN\"," << '\n';
...@@ -347,9 +335,10 @@ std::string Tree::NumericalDecisionIfElse(int node) const { ...@@ -347,9 +335,10 @@ std::string Tree::NumericalDecisionIfElse(int node) const {
std::stringstream str_buf; std::stringstream str_buf;
uint8_t missing_type = GetMissingType(decision_type_[node]); uint8_t missing_type = GetMissingType(decision_type_[node]);
bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask); bool default_left = GetDecisionType(decision_type_[node], kDefaultLeftMask);
if (missing_type == 0 || (missing_type == 1 && default_left && kZeroThreshold < threshold_[node])) { if (missing_type == MissingType::None
|| (missing_type == MissingType::Zero && default_left && kZeroThreshold < threshold_[node])) {
str_buf << "if (fval <= " << threshold_[node] << ") {"; str_buf << "if (fval <= " << threshold_[node] << ") {";
} else if (missing_type == 1) { } else if (missing_type == MissingType::Zero) {
if (default_left) { if (default_left) {
str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {"; str_buf << "if (fval <= " << threshold_[node] << " || Tree::IsZero(fval)" << " || std::isnan(fval)) {";
} else { } else {
...@@ -368,7 +357,7 @@ std::string Tree::NumericalDecisionIfElse(int node) const { ...@@ -368,7 +357,7 @@ std::string Tree::NumericalDecisionIfElse(int node) const {
std::string Tree::CategoricalDecisionIfElse(int node) const { std::string Tree::CategoricalDecisionIfElse(int node) const {
uint8_t missing_type = GetMissingType(decision_type_[node]); uint8_t missing_type = GetMissingType(decision_type_[node]);
std::stringstream str_buf; std::stringstream str_buf;
if (missing_type == 2) { if (missing_type == MissingType::NaN) {
str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }"; str_buf << "if (std::isnan(fval)) { int_fval = -1; } else { int_fval = static_cast<int>(fval); }";
} else { } else {
str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }"; str_buf << "if (std::isnan(fval)) { int_fval = 0; } else { int_fval = static_cast<int>(fval); }";
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment