Unverified Commit 73bc8ed7 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

shrinkage to internal values (#2853)

parent 7d700cd3
...@@ -89,12 +89,7 @@ class Tree { ...@@ -89,12 +89,7 @@ class Tree {
/*! \brief Set the output of one leaf */ /*! \brief Set the output of one leaf */
inline void SetLeafOutput(int leaf, double output) { inline void SetLeafOutput(int leaf, double output) {
// Prevent denormal values because they can cause std::out_of_range exception when converting strings to doubles leaf_value_[leaf] = MaybeRoundToZero(output);
if (IsZero(output)) {
leaf_value_[leaf] = 0;
} else {
leaf_value_[leaf] = output;
}
} }
/*! /*!
...@@ -156,40 +151,33 @@ class Tree { ...@@ -156,40 +151,33 @@ class Tree {
/*! \brief Get the number of data points that fall at or below this node*/ /*! \brief Get the number of data points that fall at or below this node*/
inline int data_count(int node) const { return node >= 0 ? internal_count_[node] : leaf_count_[~node]; } inline int data_count(int node) const { return node >= 0 ? internal_count_[node] : leaf_count_[~node]; }
/*! /*!
* \brief Shrinkage for the tree's output * \brief Shrinkage for the tree's output
* shrinkage rate (a.k.a learning rate) is used to tune the training process * shrinkage rate (a.k.a learning rate) is used to tune the training process
* \param rate The factor of shrinkage * \param rate The factor of shrinkage
*/ */
inline void Shrinkage(double rate) { inline void Shrinkage(double rate) {
#pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048) #pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
for (int i = 0; i < num_leaves_; ++i) { for (int i = 0; i < num_leaves_ - 1; ++i) {
double new_leaf_value = leaf_value_[i] * rate; leaf_value_[i] = MaybeRoundToZero(leaf_value_[i] * rate);
// Prevent denormal values because they can cause std::out_of_range exception when converting strings to doubles internal_value_[i] = MaybeRoundToZero(internal_value_[i] * rate);
if (IsZero(new_leaf_value)) {
leaf_value_[i] = 0;
} else {
leaf_value_[i] = new_leaf_value;
}
} }
leaf_value_[num_leaves_ - 1] =
MaybeRoundToZero(leaf_value_[num_leaves_ - 1] * rate);
shrinkage_ *= rate; shrinkage_ *= rate;
} }
inline double shrinkage() const { inline double shrinkage() const { return shrinkage_; }
return shrinkage_;
}
inline void AddBias(double val) { inline void AddBias(double val) {
#pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048) #pragma omp parallel for schedule(static, 1024) if (num_leaves_ >= 2048)
for (int i = 0; i < num_leaves_; ++i) { for (int i = 0; i < num_leaves_ - 1; ++i) {
double new_leaf_value = val + leaf_value_[i]; leaf_value_[i] = MaybeRoundToZero(leaf_value_[i] + val);
// Prevent denormal values because they can cause std::out_of_range exception when converting strings to doubles internal_value_[i] = MaybeRoundToZero(internal_value_[i] + val);
if (IsZero(new_leaf_value)) {
leaf_value_[i] = 0;
} else {
leaf_value_[i] = new_leaf_value;
}
} }
leaf_value_[num_leaves_ - 1] =
MaybeRoundToZero(leaf_value_[num_leaves_ - 1] + val);
// force to 1.0 // force to 1.0
shrinkage_ = 1.0f; shrinkage_ = 1.0f;
} }
...@@ -217,6 +205,14 @@ class Tree { ...@@ -217,6 +205,14 @@ class Tree {
} }
} }
inline static double MaybeRoundToZero(double fval) {
if (fval > -kZeroThreshold && fval <= kZeroThreshold) {
return 0;
} else {
return fval;
}
}
inline static bool GetDecisionType(int8_t decision_type, int8_t mask) { inline static bool GetDecisionType(int8_t decision_type, int8_t mask) {
return (decision_type & mask) > 0; return (decision_type & mask) > 0;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment