Commit a6a75fe9 authored by Qiwei Ye, committed via GitHub
Browse files

Merge pull request #35 from guolinke/master

support max_depth option, to solve (#33)
parents 9895116d 4aa63045
......@@ -137,11 +137,17 @@ struct TreeConfig: public ConfigBase {
public:
int min_data_in_leaf = 100;
double min_sum_hessian_in_leaf = 10.0f;
// must be > 1; a tree with only one leaf has nothing to learn
int num_leaves = 127;
int feature_fraction_seed = 2;
double feature_fraction = 1.0;
// max cache size (unit: MB) for the historical histogram cache; < 0 means no limit
double histogram_pool_size = -1;
// max depth of the tree model.
// The tree still grows leaf-wise, but its maximum depth is limited to avoid over-fitting.
// The maximum number of leaves will be min(num_leaves, pow(2, max_depth - 1)).
// max_depth < 0 means no limit
int max_depth = -1;
void Set(const std::unordered_map<std::string, std::string>& params) override;
};
......
......@@ -80,6 +80,9 @@ public:
/*! \brief Current number of leaves in the tree */
inline int num_leaves() const {
  return num_leaves_;
}
/*!
 * \brief Get the depth of a specific leaf
 * \param leaf_idx index of the leaf
 * \return depth of that leaf (the root is at depth 1)
 */
inline int leaf_depth(int leaf_idx) const {
  return leaf_depth_[leaf_idx];
}
/*!
* \brief Shrinkage for the tree's output
* shrinkage rate (a.k.a. learning rate) is used to tune the training process
......@@ -139,6 +142,8 @@ private:
int* leaf_parent_;
/*! \brief Output of leaves */
score_t* leaf_value_;
/*! \brief Depth for leaves */
int* leaf_depth_;
};
......
......@@ -228,6 +228,8 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
GetDouble(params, "feature_fraction", &feature_fraction);
CHECK(feature_fraction > 0.0 && feature_fraction <= 1.0);
GetDouble(params, "histogram_pool_size", &histogram_pool_size);
GetInt(params, "max_depth", &max_depth);
CHECK(max_depth > 1 || max_depth < 0);
}
......
......@@ -28,6 +28,9 @@ Tree::Tree(int max_leaves)
leaf_parent_ = new int[max_leaves_];
leaf_value_ = new score_t[max_leaves_];
leaf_depth_ = new int[max_leaves_];
// the root is at depth 1
leaf_depth_[0] = 1;
num_leaves_ = 1;
leaf_parent_[0] = -1;
}
......@@ -41,6 +44,7 @@ Tree::~Tree() {
if (threshold_ != nullptr) { delete[] threshold_; }
if (split_gain_ != nullptr) { delete[] split_gain_; }
if (leaf_value_ != nullptr) { delete[] leaf_value_; }
if (leaf_depth_ != nullptr) { delete[] leaf_depth_; }
}
int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feature,
......@@ -70,9 +74,11 @@ int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feat
leaf_parent_[num_leaves_] = new_node_idx;
leaf_value_[leaf] = left_value;
leaf_value_[num_leaves_] = right_value;
// update leaf depth
leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
leaf_depth_[leaf]++;
++num_leaves_;
return num_leaves_ - 1;
}
......@@ -155,6 +161,7 @@ Tree::Tree(const std::string& str) {
split_feature_ = nullptr;
threshold_in_bin_ = nullptr;
leaf_depth_ = nullptr;
Common::StringToIntArray(key_vals["split_feature"], ' ',
num_leaves_ - 1, split_feature_real_);
......
......@@ -19,6 +19,7 @@ SerialTreeLearner::SerialTreeLearner(const TreeConfig& tree_config)
feature_fraction_ = tree_config.feature_fraction;
random_ = Random(tree_config.feature_fraction_seed);
histogram_pool_size_ = tree_config.histogram_pool_size;
max_depth_ = tree_config.max_depth;
}
SerialTreeLearner::~SerialTreeLearner() {
......@@ -120,6 +121,8 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
// some initial works before training
BeforeTrain();
Tree *tree = new Tree(num_leaves_);
// save pointer to last trained tree
last_trained_tree_ = tree;
// root leaf
int left_leaf = 0;
// only root leaf can be splitted on first time
......@@ -145,8 +148,6 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
// split tree with best leaf
Split(tree, best_leaf, &left_leaf, &right_leaf);
}
// save pointer to last trained tree
last_trained_tree_ = tree;
return tree;
}
......@@ -234,6 +235,17 @@ void SerialTreeLearner::BeforeTrain() {
}
bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
// check depth of current leaf
if (max_depth_ > 0) {
// only need to check the left leaf, since the right leaf is at the same depth as the left leaf
if (last_trained_tree_->leaf_depth(left_leaf) >= max_depth_) {
best_split_per_leaf_[left_leaf].gain = kMinScore;
if (right_leaf >= 0) {
best_split_per_leaf_[right_leaf].gain = kMinScore;
}
return false;
}
}
data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
// no enough data to continue
......
......@@ -163,6 +163,8 @@ protected:
double histogram_pool_size_;
/*! \brief used to cache historical histogram to speed up*/
LRUPool<FeatureHistogram*> histogram_pool_;
/*! \brief max depth of tree model */
int max_depth_;
};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment