Commit 4aa63045 authored by Guolin Ke's avatar Guolin Ke
Browse files

support max_depth option.

parent 9895116d
...@@ -137,11 +137,17 @@ struct TreeConfig: public ConfigBase { ...@@ -137,11 +137,17 @@ struct TreeConfig: public ConfigBase {
public: public:
int min_data_in_leaf = 100; int min_data_in_leaf = 100;
double min_sum_hessian_in_leaf = 10.0f; double min_sum_hessian_in_leaf = 10.0f;
// should be > 1; a tree with only one leaf does not need any learning
int num_leaves = 127; int num_leaves = 127;
int feature_fraction_seed = 2; int feature_fraction_seed = 2;
double feature_fraction = 1.0; double feature_fraction = 1.0;
// max cache size(unit:MB) for historical histogram. < 0 means not limit // max cache size(unit:MB) for historical histogram. < 0 means not limit
double histogram_pool_size = -1; double histogram_pool_size = -1;
// max depth of the tree model
// The tree is still grown leaf-wise, but the maximum depth is limited to avoid over-fitting
// The maximum number of leaves will be min(num_leaves, pow(2, max_depth - 1))
// max_depth < 0 means no limit
int max_depth = -1;
void Set(const std::unordered_map<std::string, std::string>& params) override; void Set(const std::unordered_map<std::string, std::string>& params) override;
}; };
......
...@@ -80,6 +80,9 @@ public: ...@@ -80,6 +80,9 @@ public:
/*! \brief Get Number of leaves*/ /*! \brief Get Number of leaves*/
inline int num_leaves() const { return num_leaves_; } inline int num_leaves() const { return num_leaves_; }
/*! \brief Get depth of a specific leaf*/
inline int leaf_depth(int leaf_idx) const { return leaf_depth_[leaf_idx]; }
/*! /*!
* \brief Shrinkage for the tree's output * \brief Shrinkage for the tree's output
* shrinkage rate (a.k.a. learning rate) is used to tune the training process * shrinkage rate (a.k.a. learning rate) is used to tune the training process
...@@ -139,6 +142,8 @@ private: ...@@ -139,6 +142,8 @@ private:
int* leaf_parent_; int* leaf_parent_;
/*! \brief Output of leaves */ /*! \brief Output of leaves */
score_t* leaf_value_; score_t* leaf_value_;
/*! \brief Depth for leaves */
int* leaf_depth_;
}; };
......
...@@ -228,6 +228,8 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params) ...@@ -228,6 +228,8 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
GetDouble(params, "feature_fraction", &feature_fraction); GetDouble(params, "feature_fraction", &feature_fraction);
CHECK(feature_fraction > 0.0 && feature_fraction <= 1.0); CHECK(feature_fraction > 0.0 && feature_fraction <= 1.0);
GetDouble(params, "histogram_pool_size", &histogram_pool_size); GetDouble(params, "histogram_pool_size", &histogram_pool_size);
GetInt(params, "max_depth", &max_depth);
CHECK(max_depth > 1 || max_depth < 0);
} }
......
...@@ -28,6 +28,9 @@ Tree::Tree(int max_leaves) ...@@ -28,6 +28,9 @@ Tree::Tree(int max_leaves)
leaf_parent_ = new int[max_leaves_]; leaf_parent_ = new int[max_leaves_];
leaf_value_ = new score_t[max_leaves_]; leaf_value_ = new score_t[max_leaves_];
leaf_depth_ = new int[max_leaves_];
// the root is at depth 1
leaf_depth_[0] = 1;
num_leaves_ = 1; num_leaves_ = 1;
leaf_parent_[0] = -1; leaf_parent_[0] = -1;
} }
...@@ -41,6 +44,7 @@ Tree::~Tree() { ...@@ -41,6 +44,7 @@ Tree::~Tree() {
if (threshold_ != nullptr) { delete[] threshold_; } if (threshold_ != nullptr) { delete[] threshold_; }
if (split_gain_ != nullptr) { delete[] split_gain_; } if (split_gain_ != nullptr) { delete[] split_gain_; }
if (leaf_value_ != nullptr) { delete[] leaf_value_; } if (leaf_value_ != nullptr) { delete[] leaf_value_; }
if (leaf_depth_ != nullptr) { delete[] leaf_depth_; }
} }
int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feature, int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feature,
...@@ -70,9 +74,11 @@ int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feat ...@@ -70,9 +74,11 @@ int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feat
leaf_parent_[num_leaves_] = new_node_idx; leaf_parent_[num_leaves_] = new_node_idx;
leaf_value_[leaf] = left_value; leaf_value_[leaf] = left_value;
leaf_value_[num_leaves_] = right_value; leaf_value_[num_leaves_] = right_value;
// update leaf depth
leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
leaf_depth_[leaf]++;
++num_leaves_; ++num_leaves_;
return num_leaves_ - 1; return num_leaves_ - 1;
} }
...@@ -155,6 +161,7 @@ Tree::Tree(const std::string& str) { ...@@ -155,6 +161,7 @@ Tree::Tree(const std::string& str) {
split_feature_ = nullptr; split_feature_ = nullptr;
threshold_in_bin_ = nullptr; threshold_in_bin_ = nullptr;
leaf_depth_ = nullptr;
Common::StringToIntArray(key_vals["split_feature"], ' ', Common::StringToIntArray(key_vals["split_feature"], ' ',
num_leaves_ - 1, split_feature_real_); num_leaves_ - 1, split_feature_real_);
......
...@@ -19,6 +19,7 @@ SerialTreeLearner::SerialTreeLearner(const TreeConfig& tree_config) ...@@ -19,6 +19,7 @@ SerialTreeLearner::SerialTreeLearner(const TreeConfig& tree_config)
feature_fraction_ = tree_config.feature_fraction; feature_fraction_ = tree_config.feature_fraction;
random_ = Random(tree_config.feature_fraction_seed); random_ = Random(tree_config.feature_fraction_seed);
histogram_pool_size_ = tree_config.histogram_pool_size; histogram_pool_size_ = tree_config.histogram_pool_size;
max_depth_ = tree_config.max_depth;
} }
SerialTreeLearner::~SerialTreeLearner() { SerialTreeLearner::~SerialTreeLearner() {
...@@ -120,6 +121,8 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians ...@@ -120,6 +121,8 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
// some initial works before training // some initial works before training
BeforeTrain(); BeforeTrain();
Tree *tree = new Tree(num_leaves_); Tree *tree = new Tree(num_leaves_);
// save pointer to last trained tree
last_trained_tree_ = tree;
// root leaf // root leaf
int left_leaf = 0; int left_leaf = 0;
// only root leaf can be splitted on first time // only root leaf can be splitted on first time
...@@ -145,8 +148,6 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians ...@@ -145,8 +148,6 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
// split tree with best leaf // split tree with best leaf
Split(tree, best_leaf, &left_leaf, &right_leaf); Split(tree, best_leaf, &left_leaf, &right_leaf);
} }
// save pointer to last trained tree
last_trained_tree_ = tree;
return tree; return tree;
} }
...@@ -234,6 +235,17 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -234,6 +235,17 @@ void SerialTreeLearner::BeforeTrain() {
} }
bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) { bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
// check depth of current leaf
if (max_depth_ > 0) {
// only the left leaf needs to be checked, since the right leaf is at the same level as the left leaf
if (last_trained_tree_->leaf_depth(left_leaf) >= max_depth_) {
best_split_per_leaf_[left_leaf].gain = kMinScore;
if (right_leaf >= 0) {
best_split_per_leaf_[right_leaf].gain = kMinScore;
}
return false;
}
}
data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf); data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf); data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
// no enough data to continue // no enough data to continue
......
...@@ -163,6 +163,8 @@ protected: ...@@ -163,6 +163,8 @@ protected:
double histogram_pool_size_; double histogram_pool_size_;
/*! \brief used to cache historical histogram to speed up*/ /*! \brief used to cache historical histogram to speed up*/
LRUPool<FeatureHistogram*> histogram_pool_; LRUPool<FeatureHistogram*> histogram_pool_;
/*! \brief max depth of tree model */
int max_depth_;
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment