Commit 1466f907 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

Categorical feature support (#108)

Categorical feature support (#108)
parent 531352f6
...@@ -100,7 +100,7 @@ public: ...@@ -100,7 +100,7 @@ public:
} }
} }
data_size_t Split(unsigned int threshold, data_size_t* data_indices, data_size_t num_data, virtual data_size_t Split(unsigned int threshold, data_size_t* data_indices, data_size_t num_data,
data_size_t* lte_indices, data_size_t* gt_indices) const override { data_size_t* lte_indices, data_size_t* gt_indices) const override {
// not need to split // not need to split
if (num_data <= 0) { return 0; } if (num_data <= 0) { return 0; }
...@@ -261,7 +261,7 @@ public: ...@@ -261,7 +261,7 @@ public:
} }
private: protected:
data_size_t num_data_; data_size_t num_data_;
std::vector<std::pair<data_size_t, VAL_T>> non_zero_pair_; std::vector<std::pair<data_size_t, VAL_T>> non_zero_pair_;
std::vector<uint8_t> deltas_; std::vector<uint8_t> deltas_;
...@@ -299,6 +299,33 @@ BinIterator* SparseBin<VAL_T>::GetIterator(data_size_t start_idx) const { ...@@ -299,6 +299,33 @@ BinIterator* SparseBin<VAL_T>::GetIterator(data_size_t start_idx) const {
} }
template <typename VAL_T>
class SparseCategoricalBin: public SparseBin<VAL_T> {
public:
  SparseCategoricalBin(data_size_t num_data, int default_bin)
    : SparseBin<VAL_T>(num_data, default_bin) {
  }
  /*!
  * \brief Partition data by an equality test on a categorical bin:
  *        rows whose bin equals the threshold go to the "lte" side,
  *        all other categories go to the "gt" side.
  * \return number of indices written to lte_indices
  */
  virtual data_size_t Split(unsigned int threshold, data_size_t* data_indices, data_size_t num_data,
    data_size_t* lte_indices, data_size_t* gt_indices) const override {
    // nothing to partition
    if (num_data <= 0) { return 0; }
    SparseBinIterator<VAL_T> it(this, data_indices[0]);
    data_size_t n_lte = 0;
    data_size_t n_gt = 0;
    for (data_size_t i = 0; i < num_data; ++i) {
      const data_size_t cur_idx = data_indices[i];
      const VAL_T cur_bin = it.InnerGet(cur_idx);
      // one-vs-rest: only the matching category lands on the left
      if (cur_bin == threshold) {
        lte_indices[n_lte++] = cur_idx;
      } else {
        gt_indices[n_gt++] = cur_idx;
      }
    }
    return n_lte;
  }
};
} // namespace LightGBM } // namespace LightGBM
#endif // LightGBM_IO_SPARSE_BIN_HPP_ #endif // LightGBM_IO_SPARSE_BIN_HPP_
...@@ -15,6 +15,12 @@ ...@@ -15,6 +15,12 @@
namespace LightGBM { namespace LightGBM {
// Decision-function tables indexed by a node's decision_type_ value
// (0 = numerical split, 1 = categorical split — see Tree::Split).
// NOTE(review): presumably the unsigned-int table is used for thresholds in
// bin space (threshold_in_bin_) and the double table for real-valued
// thresholds (threshold_) — confirm against the prediction code.
std::vector<std::function<bool(unsigned int, unsigned int)>> Tree::inner_decision_funs =
{Tree::NumericalDecision<unsigned int>, Tree::CategoricalDecision<unsigned int> };
std::vector<std::function<bool(double, double)>> Tree::decision_funs =
{ Tree::NumericalDecision<double>, Tree::CategoricalDecision<double> };
Tree::Tree(int max_leaves) Tree::Tree(int max_leaves)
:max_leaves_(max_leaves) { :max_leaves_(max_leaves) {
...@@ -25,10 +31,13 @@ Tree::Tree(int max_leaves) ...@@ -25,10 +31,13 @@ Tree::Tree(int max_leaves)
split_feature_real_ = std::vector<int>(max_leaves_ - 1); split_feature_real_ = std::vector<int>(max_leaves_ - 1);
threshold_in_bin_ = std::vector<unsigned int>(max_leaves_ - 1); threshold_in_bin_ = std::vector<unsigned int>(max_leaves_ - 1);
threshold_ = std::vector<double>(max_leaves_ - 1); threshold_ = std::vector<double>(max_leaves_ - 1);
decision_type_ = std::vector<int8_t>(max_leaves_ - 1);
split_gain_ = std::vector<double>(max_leaves_ - 1); split_gain_ = std::vector<double>(max_leaves_ - 1);
leaf_parent_ = std::vector<int>(max_leaves_); leaf_parent_ = std::vector<int>(max_leaves_);
leaf_value_ = std::vector<double>(max_leaves_); leaf_value_ = std::vector<double>(max_leaves_);
leaf_count_ = std::vector<data_size_t>(max_leaves_);
internal_value_ = std::vector<double>(max_leaves_ - 1); internal_value_ = std::vector<double>(max_leaves_ - 1);
internal_count_ = std::vector<data_size_t>(max_leaves_ - 1);
leaf_depth_ = std::vector<int>(max_leaves_); leaf_depth_ = std::vector<int>(max_leaves_);
// root is in the depth 1 // root is in the depth 1
leaf_depth_[0] = 1; leaf_depth_[0] = 1;
...@@ -39,8 +48,9 @@ Tree::~Tree() { ...@@ -39,8 +48,9 @@ Tree::~Tree() {
} }
int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feature, int Tree::Split(int leaf, int feature, BinType bin_type, unsigned int threshold_bin, int real_feature,
double threshold, double left_value, double right_value, double gain) { double threshold_double, double left_value,
double right_value, data_size_t left_cnt, data_size_t right_cnt, double gain) {
int new_node_idx = num_leaves_ - 1; int new_node_idx = num_leaves_ - 1;
// update parent info // update parent info
int parent = leaf_parent_[leaf]; int parent = leaf_parent_[leaf];
...@@ -56,7 +66,12 @@ int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feat ...@@ -56,7 +66,12 @@ int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feat
split_feature_[new_node_idx] = feature; split_feature_[new_node_idx] = feature;
split_feature_real_[new_node_idx] = real_feature; split_feature_real_[new_node_idx] = real_feature;
threshold_in_bin_[new_node_idx] = threshold_bin; threshold_in_bin_[new_node_idx] = threshold_bin;
threshold_[new_node_idx] = threshold; threshold_[new_node_idx] = threshold_double;
if (bin_type == BinType::NumericalBin) {
decision_type_[new_node_idx] = 0;
} else {
decision_type_[new_node_idx] = 1;
}
split_gain_[new_node_idx] = gain; split_gain_[new_node_idx] = gain;
// add two new leaves // add two new leaves
left_child_[new_node_idx] = ~leaf; left_child_[new_node_idx] = ~leaf;
...@@ -66,8 +81,11 @@ int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feat ...@@ -66,8 +81,11 @@ int Tree::Split(int leaf, int feature, unsigned int threshold_bin, int real_feat
leaf_parent_[num_leaves_] = new_node_idx; leaf_parent_[num_leaves_] = new_node_idx;
// save current leaf value to internal node before change // save current leaf value to internal node before change
internal_value_[new_node_idx] = leaf_value_[leaf]; internal_value_[new_node_idx] = leaf_value_[leaf];
internal_count_[new_node_idx] = left_cnt + right_cnt;
leaf_value_[leaf] = left_value; leaf_value_[leaf] = left_value;
leaf_count_[leaf] = left_cnt;
leaf_value_[num_leaves_] = right_value; leaf_value_[num_leaves_] = right_value;
leaf_count_[num_leaves_] = right_cnt;
// update leaf depth // update leaf depth
leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1; leaf_depth_[num_leaves_] = leaf_depth_[leaf] + 1;
leaf_depth_[leaf]++; leaf_depth_[leaf]++;
...@@ -106,21 +124,27 @@ std::string Tree::ToString() { ...@@ -106,21 +124,27 @@ std::string Tree::ToString() {
std::stringstream ss; std::stringstream ss;
ss << "num_leaves=" << num_leaves_ << std::endl; ss << "num_leaves=" << num_leaves_ << std::endl;
ss << "split_feature=" ss << "split_feature="
<< Common::ArrayToString<int>(split_feature_real_.data(), num_leaves_ - 1, ' ') << std::endl; << Common::ArrayToString<int>(split_feature_real_, ' ') << std::endl;
ss << "split_gain=" ss << "split_gain="
<< Common::ArrayToString<double>(split_gain_.data(), num_leaves_ - 1, ' ') << std::endl; << Common::ArrayToString<double>(split_gain_, ' ') << std::endl;
ss << "threshold=" ss << "threshold="
<< Common::ArrayToString<double>(threshold_.data(), num_leaves_ - 1, ' ') << std::endl; << Common::ArrayToString<double>(threshold_, ' ') << std::endl;
ss << "decision_type="
<< Common::ArrayToString<int>(Common::ArrayCast<int8_t, int>(decision_type_), ' ') << std::endl;
ss << "left_child=" ss << "left_child="
<< Common::ArrayToString<int>(left_child_.data(), num_leaves_ - 1, ' ') << std::endl; << Common::ArrayToString<int>(left_child_, ' ') << std::endl;
ss << "right_child=" ss << "right_child="
<< Common::ArrayToString<int>(right_child_.data(), num_leaves_ - 1, ' ') << std::endl; << Common::ArrayToString<int>(right_child_, ' ') << std::endl;
ss << "leaf_parent=" ss << "leaf_parent="
<< Common::ArrayToString<int>(leaf_parent_.data(), num_leaves_, ' ') << std::endl; << Common::ArrayToString<int>(leaf_parent_, ' ') << std::endl;
ss << "leaf_value=" ss << "leaf_value="
<< Common::ArrayToString<double>(leaf_value_.data(), num_leaves_, ' ') << std::endl; << Common::ArrayToString<double>(leaf_value_, ' ') << std::endl;
ss << "leaf_count="
<< Common::ArrayToString<data_size_t>(leaf_count_, ' ') << std::endl;
ss << "internal_value=" ss << "internal_value="
<< Common::ArrayToString<double>(internal_value_.data(), num_leaves_ - 1, ' ') << std::endl; << Common::ArrayToString<double>(internal_value_, ' ') << std::endl;
ss << "internal_count="
<< Common::ArrayToString<data_size_t>(internal_count_, ' ') << std::endl;
ss << std::endl; ss << std::endl;
return ss.str(); return ss.str();
} }
...@@ -142,20 +166,23 @@ std::string Tree::NodeToJSON(int index) { ...@@ -142,20 +166,23 @@ std::string Tree::NodeToJSON(int index) {
// non-leaf // non-leaf
ss << "{" << std::endl; ss << "{" << std::endl;
ss << "\"split_index\":" << index << "," << std::endl; ss << "\"split_index\":" << index << "," << std::endl;
ss << "\"split_feature\":" << split_feature_real_.data()[index] << "," << std::endl; ss << "\"split_feature\":" << split_feature_real_[index] << "," << std::endl;
ss << "\"split_gain\":" << split_gain_.data()[index] << "," << std::endl; ss << "\"split_gain\":" << split_gain_[index] << "," << std::endl;
ss << "\"threshold\":" << threshold_.data()[index] << "," << std::endl; ss << "\"threshold\":" << threshold_[index] << "," << std::endl;
ss << "\"internal_value\":" << internal_value_.data()[index] << "," << std::endl; ss << "\"decision_type\":\"" << Tree::GetDecisionTypeName(decision_type_[index]) << "\"," << std::endl;
ss << "\"left_child\":" << NodeToJSON(left_child_.data()[index]) << "," << std::endl; ss << "\"internal_value\":" << internal_value_[index] << "," << std::endl;
ss << "\"right_child\":" << NodeToJSON(right_child_.data()[index]) << std::endl; ss << "\"internal_count\":" << internal_count_[index] << "," << std::endl;
ss << "\"left_child\":" << NodeToJSON(left_child_[index]) << "," << std::endl;
ss << "\"right_child\":" << NodeToJSON(right_child_[index]) << std::endl;
ss << "}"; ss << "}";
} else { } else {
// leaf // leaf
index = ~index; index = ~index;
ss << "{" << std::endl; ss << "{" << std::endl;
ss << "\"leaf_index\":" << index << "," << std::endl; ss << "\"leaf_index\":" << index << "," << std::endl;
ss << "\"leaf_parent\":" << leaf_parent_.data()[index] << "," << std::endl; ss << "\"leaf_parent\":" << leaf_parent_[index] << "," << std::endl;
ss << "\"leaf_value\":" << leaf_value_.data()[index] << std::endl; ss << "\"leaf_value\":" << leaf_value_[index] << "," << std::endl;
ss << "\"leaf_count\":" << leaf_count_[index] << std::endl;
ss << "}"; ss << "}";
} }
...@@ -179,37 +206,29 @@ Tree::Tree(const std::string& str) { ...@@ -179,37 +206,29 @@ Tree::Tree(const std::string& str) {
|| key_vals.count("split_gain") <= 0 || key_vals.count("threshold") <= 0 || key_vals.count("split_gain") <= 0 || key_vals.count("threshold") <= 0
|| key_vals.count("left_child") <= 0 || key_vals.count("right_child") <= 0 || key_vals.count("left_child") <= 0 || key_vals.count("right_child") <= 0
|| key_vals.count("leaf_parent") <= 0 || key_vals.count("leaf_value") <= 0 || key_vals.count("leaf_parent") <= 0 || key_vals.count("leaf_value") <= 0
|| key_vals.count("internal_value") <= 0) { || key_vals.count("internal_value") <= 0 || key_vals.count("internal_count") <= 0
|| key_vals.count("leaf_count") <= 0 || key_vals.count("decision_type") <= 0
) {
Log::Fatal("Tree model string format error"); Log::Fatal("Tree model string format error");
} }
Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_); Common::Atoi(key_vals["num_leaves"].c_str(), &num_leaves_);
left_child_ = std::vector<int>(num_leaves_ - 1); left_child_ = Common::StringToArray<int>(key_vals["left_child"], ' ', num_leaves_ - 1);
right_child_ = std::vector<int>(num_leaves_ - 1); right_child_ = Common::StringToArray<int>(key_vals["right_child"], ' ', num_leaves_ - 1);
split_feature_real_ = std::vector<int>(num_leaves_ - 1); split_feature_real_ = Common::StringToArray<int>(key_vals["split_feature"], ' ', num_leaves_ - 1);
threshold_ = std::vector<double>(num_leaves_ - 1); threshold_ = Common::StringToArray<double>(key_vals["threshold"], ' ', num_leaves_ - 1);
split_gain_ = std::vector<double>(num_leaves_ - 1); split_gain_ = Common::StringToArray<double>(key_vals["split_gain"], ' ', num_leaves_ - 1);
leaf_parent_ = std::vector<int>(num_leaves_); internal_count_ = Common::StringToArray<data_size_t>(key_vals["internal_count"], ' ', num_leaves_ - 1);
leaf_value_ = std::vector<double>(num_leaves_); internal_value_ = Common::StringToArray<double>(key_vals["internal_value"], ' ', num_leaves_ - 1);
internal_value_ = std::vector<double>(num_leaves_ - 1); decision_type_ = Common::StringToArray<int8_t>(key_vals["decision_type"], ' ', num_leaves_ - 1);
Common::StringToIntArray(key_vals["split_feature"], ' ', leaf_count_ = Common::StringToArray<data_size_t>(key_vals["leaf_count"], ' ', num_leaves_);
num_leaves_ - 1, split_feature_real_.data()); leaf_parent_ = Common::StringToArray<int>(key_vals["leaf_parent"], ' ', num_leaves_);
Common::StringToDoubleArray(key_vals["split_gain"], ' ', leaf_value_ = Common::StringToArray<double>(key_vals["leaf_value"], ' ', num_leaves_);
num_leaves_ - 1, split_gain_.data());
Common::StringToDoubleArray(key_vals["threshold"], ' ',
num_leaves_ - 1, threshold_.data());
Common::StringToIntArray(key_vals["left_child"], ' ',
num_leaves_ - 1, left_child_.data());
Common::StringToIntArray(key_vals["right_child"], ' ',
num_leaves_ - 1, right_child_.data());
Common::StringToIntArray(key_vals["leaf_parent"], ' ',
num_leaves_ , leaf_parent_.data());
Common::StringToDoubleArray(key_vals["leaf_value"], ' ',
num_leaves_ , leaf_value_.data());
Common::StringToDoubleArray(key_vals["internal_value"], ' ',
num_leaves_ - 1 , internal_value_.data());
} }
} // namespace LightGBM } // namespace LightGBM
...@@ -37,7 +37,7 @@ void DataParallelTreeLearner::Init(const Dataset* train_data) { ...@@ -37,7 +37,7 @@ void DataParallelTreeLearner::Init(const Dataset* train_data) {
buffer_write_start_pos_.resize(num_features_); buffer_write_start_pos_.resize(num_features_);
buffer_read_start_pos_.resize(num_features_); buffer_read_start_pos_.resize(num_features_);
global_data_count_in_leaf_.resize(num_leaves_); global_data_count_in_leaf_.resize(tree_config_.num_leaves);
} }
......
...@@ -28,17 +28,17 @@ public: ...@@ -28,17 +28,17 @@ public:
* \param feature the feature data for this histogram * \param feature the feature data for this histogram
* \param min_num_data_one_leaf minimal number of data in one leaf * \param min_num_data_one_leaf minimal number of data in one leaf
*/ */
void Init(const Feature* feature, int feature_idx, data_size_t min_num_data_one_leaf, void Init(const Feature* feature, int feature_idx, const TreeConfig* tree_config) {
double min_sum_hessian_one_leaf, double lambda_l1, double lambda_l2, double min_gain_to_split) {
feature_idx_ = feature_idx; feature_idx_ = feature_idx;
min_num_data_one_leaf_ = min_num_data_one_leaf; tree_config_ = tree_config;
min_sum_hessian_one_leaf_ = min_sum_hessian_one_leaf;
lambda_l1_ = lambda_l1;
lambda_l2_ = lambda_l2;
min_gain_to_split_ = min_gain_to_split;
bin_data_ = feature->bin_data(); bin_data_ = feature->bin_data();
num_bins_ = feature->num_bin(); num_bins_ = feature->num_bin();
data_.resize(num_bins_); data_.resize(num_bins_);
if (feature->bin_type() == BinType::NumericalBin) {
find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdForNumerical, this, std::placeholders::_1);
} else {
find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdForCategorical, this, std::placeholders::_1);
}
} }
...@@ -110,6 +110,10 @@ public: ...@@ -110,6 +110,10 @@ public:
* \param output The best split result * \param output The best split result
*/ */
void FindBestThreshold(SplitInfo* output) { void FindBestThreshold(SplitInfo* output) {
find_best_threshold_fun_(output);
}
void FindBestThresholdForNumerical(SplitInfo* output) {
double best_sum_left_gradient = NAN; double best_sum_left_gradient = NAN;
double best_sum_left_hessian = NAN; double best_sum_left_hessian = NAN;
double best_gain = kMinScore; double best_gain = kMinScore;
...@@ -119,7 +123,7 @@ public: ...@@ -119,7 +123,7 @@ public:
double sum_right_hessian = kEpsilon; double sum_right_hessian = kEpsilon;
data_size_t right_count = 0; data_size_t right_count = 0;
double gain_shift = GetLeafSplitGain(sum_gradients_, sum_hessians_); double gain_shift = GetLeafSplitGain(sum_gradients_, sum_hessians_);
double min_gain_shift = gain_shift + min_gain_to_split_; double min_gain_shift = gain_shift + tree_config_->min_gain_to_split;
is_splittable_ = false; is_splittable_ = false;
// from right to left, and we don't need data in bin0 // from right to left, and we don't need data in bin0
for (unsigned int t = num_bins_ - 1; t > 0; --t) { for (unsigned int t = num_bins_ - 1; t > 0; --t) {
...@@ -127,18 +131,20 @@ public: ...@@ -127,18 +131,20 @@ public:
sum_right_hessian += data_[t].sum_hessians; sum_right_hessian += data_[t].sum_hessians;
right_count += data_[t].cnt; right_count += data_[t].cnt;
// if data not enough, or sum hessian too small // if data not enough, or sum hessian too small
if (right_count < min_num_data_one_leaf_ || sum_right_hessian < min_sum_hessian_one_leaf_) continue; if (right_count < tree_config_->min_data_in_leaf
|| sum_right_hessian < tree_config_->min_sum_hessian_in_leaf) continue;
data_size_t left_count = num_data_ - right_count; data_size_t left_count = num_data_ - right_count;
// if data not enough // if data not enough
if (left_count < min_num_data_one_leaf_) break; if (left_count < tree_config_->min_data_in_leaf) break;
double sum_left_hessian = sum_hessians_ - sum_right_hessian; double sum_left_hessian = sum_hessians_ - sum_right_hessian;
// if sum hessian too small // if sum hessian too small
if (sum_left_hessian < min_sum_hessian_one_leaf_) break; if (sum_left_hessian < tree_config_->min_sum_hessian_in_leaf) break;
double sum_left_gradient = sum_gradients_ - sum_right_gradient; double sum_left_gradient = sum_gradients_ - sum_right_gradient;
// current split gain // current split gain
double current_gain = GetLeafSplitGain(sum_left_gradient, sum_left_hessian) + GetLeafSplitGain(sum_right_gradient, sum_right_hessian); double current_gain = GetLeafSplitGain(sum_left_gradient, sum_left_hessian)
+ GetLeafSplitGain(sum_right_gradient, sum_right_hessian);
// gain with split is worse than without split // gain with split is worse than without split
if (current_gain < min_gain_shift) continue; if (current_gain < min_gain_shift) continue;
...@@ -154,6 +160,7 @@ public: ...@@ -154,6 +160,7 @@ public:
best_gain = current_gain; best_gain = current_gain;
} }
} }
if (is_splittable_) {
// update split information // update split information
output->feature = feature_idx_; output->feature = feature_idx_;
output->threshold = best_threshold; output->threshold = best_threshold;
...@@ -167,6 +174,75 @@ public: ...@@ -167,6 +174,75 @@ public:
output->right_sum_gradient = sum_gradients_ - best_sum_left_gradient; output->right_sum_gradient = sum_gradients_ - best_sum_left_gradient;
output->right_sum_hessian = sum_hessians_ - best_sum_left_hessian; output->right_sum_hessian = sum_hessians_ - best_sum_left_hessian;
output->gain = best_gain - gain_shift; output->gain = best_gain - gain_shift;
} else {
output->feature = feature_idx_;
output->gain = kMinScore;
}
}
/*!
* \brief Find the best one-vs-rest split for a categorical feature:
*        each bin (category) is tried alone on the "<=" (left) side
*        against all remaining categories on the right.
* \param output The best split result (gain set to kMinScore if no
*        admissible split exists)
*/
void FindBestThresholdForCategorical(SplitInfo* output) {
  double best_gain = kMinScore;
  unsigned int best_threshold = static_cast<unsigned int>(num_bins_);
  // gain without splitting; a candidate must exceed it by min_gain_to_split
  double gain_shift = GetLeafSplitGain(sum_gradients_, sum_hessians_);
  double min_gain_shift = gain_shift + tree_config_->min_gain_to_split;
  is_splittable_ = false;
  for (int t = num_bins_ - 1; t >= 0; --t) {
    double sum_current_gradient = data_[t].sum_gradients;
    double sum_current_hessian = data_[t].sum_hessians;
    data_size_t current_count = data_[t].cnt;
    // if data not enough, or sum hessian too small
    if (current_count < tree_config_->min_data_in_leaf
      || sum_current_hessian < tree_config_->min_sum_hessian_in_leaf) continue;
    data_size_t other_count = num_data_ - current_count;
    // if data not enough
    if (other_count < tree_config_->min_data_in_leaf) continue;
    double sum_other_hessian = sum_hessians_ - sum_current_hessian;
    // if sum hessian too small
    if (sum_other_hessian < tree_config_->min_sum_hessian_in_leaf) continue;
    double sum_other_gradient = sum_gradients_ - sum_current_gradient;
    // current split gain
    double current_gain = GetLeafSplitGain(sum_other_gradient, sum_other_hessian)
      + GetLeafSplitGain(sum_current_gradient, sum_current_hessian);
    // gain with split is worse than without split
    if (current_gain < min_gain_shift) continue;
    // mark to is splittable
    is_splittable_ = true;
    // better split point
    if (current_gain > best_gain) {
      best_threshold = static_cast<unsigned int>(t);
      best_gain = current_gain;
    }
  }
  // update split information
  if (is_splittable_) {
    // hoist the repeated data_[best_threshold] / complement subexpressions
    const double left_sum_gradient = data_[best_threshold].sum_gradients;
    const double left_sum_hessian = data_[best_threshold].sum_hessians;
    const data_size_t left_count = data_[best_threshold].cnt;
    const double right_sum_gradient = sum_gradients_ - left_sum_gradient;
    const double right_sum_hessian = sum_hessians_ - left_sum_hessian;
    output->feature = feature_idx_;
    output->threshold = best_threshold;
    output->left_output = CalculateSplittedLeafOutput(left_sum_gradient, left_sum_hessian);
    output->left_count = left_count;
    output->left_sum_gradient = left_sum_gradient;
    output->left_sum_hessian = left_sum_hessian;
    output->right_output = CalculateSplittedLeafOutput(right_sum_gradient, right_sum_hessian);
    output->right_count = num_data_ - left_count;
    output->right_sum_gradient = right_sum_gradient;
    output->right_sum_hessian = right_sum_hessian;
    output->gain = best_gain - gain_shift;
  } else {
    // no admissible split on this feature
    output->feature = feature_idx_;
    output->gain = kMinScore;
  }
}
/*! /*!
...@@ -190,20 +266,6 @@ public: ...@@ -190,20 +266,6 @@ public:
std::memcpy(data_.data(), memory_data, num_bins_ * sizeof(HistogramBinEntry)); std::memcpy(data_.data(), memory_data, num_bins_ * sizeof(HistogramBinEntry));
} }
/*!
* \brief Set min number data in one leaf
*/
void SetMinNumDataOneLeaf(data_size_t new_val) {
min_num_data_one_leaf_ = new_val;
}
/*!
* \brief Set min sum hessian in one leaf
*/
void SetMinSumHessianOneLeaf(double new_val) {
min_sum_hessian_one_leaf_ = new_val;
}
/*! /*!
* \brief True if this histogram can be splitted * \brief True if this histogram can be splitted
*/ */
...@@ -223,9 +285,10 @@ private: ...@@ -223,9 +285,10 @@ private:
*/ */
double GetLeafSplitGain(double sum_gradients, double sum_hessians) const { double GetLeafSplitGain(double sum_gradients, double sum_hessians) const {
double abs_sum_gradients = std::fabs(sum_gradients); double abs_sum_gradients = std::fabs(sum_gradients);
if (abs_sum_gradients > lambda_l1_) { if (abs_sum_gradients > tree_config_->lambda_l1) {
double reg_abs_sum_gradients = abs_sum_gradients - lambda_l1_; double reg_abs_sum_gradients = abs_sum_gradients - tree_config_->lambda_l1;
return (reg_abs_sum_gradients * reg_abs_sum_gradients) / (sum_hessians + lambda_l2_); return (reg_abs_sum_gradients * reg_abs_sum_gradients)
/ (sum_hessians + tree_config_->lambda_l2);
} }
return 0.0f; return 0.0f;
} }
...@@ -238,23 +301,16 @@ private: ...@@ -238,23 +301,16 @@ private:
*/ */
double CalculateSplittedLeafOutput(double sum_gradients, double sum_hessians) const { double CalculateSplittedLeafOutput(double sum_gradients, double sum_hessians) const {
double abs_sum_gradients = std::fabs(sum_gradients); double abs_sum_gradients = std::fabs(sum_gradients);
if (abs_sum_gradients > lambda_l1_) { if (abs_sum_gradients > tree_config_->lambda_l1) {
return -std::copysign(abs_sum_gradients - lambda_l1_, sum_gradients) / (sum_hessians + lambda_l2_); return -std::copysign(abs_sum_gradients - tree_config_->lambda_l1, sum_gradients)
/ (sum_hessians + tree_config_->lambda_l2);
} }
return 0.0f; return 0.0f;
} }
int feature_idx_; int feature_idx_;
/*! \brief minimal number of data in one leaf */ /*! \brief pointer of tree config */
data_size_t min_num_data_one_leaf_; const TreeConfig* tree_config_;
/*! \brief minimal sum hessian of data in one leaf */
double min_sum_hessian_one_leaf_;
/*! \brief lambda of the L1 weights regularization */
double lambda_l1_;
/*! \brief lambda of the L2 weights regularization */
double lambda_l2_;
/*! \brief minimal gain (loss reduction) to split */
double min_gain_to_split_;
/*! \brief the bin data of current feature */ /*! \brief the bin data of current feature */
const Bin* bin_data_; const Bin* bin_data_;
/*! \brief number of bin of histogram */ /*! \brief number of bin of histogram */
...@@ -269,6 +325,8 @@ private: ...@@ -269,6 +325,8 @@ private:
double sum_hessians_; double sum_hessians_;
/*! \brief False if this histogram cannot split */ /*! \brief False if this histogram cannot split */
bool is_splittable_ = true; bool is_splittable_ = true;
/*! \brief function that used to find best threshold */
std::function<void(SplitInfo*)> find_best_threshold_fun_;
}; };
......
...@@ -7,18 +7,9 @@ ...@@ -7,18 +7,9 @@
namespace LightGBM { namespace LightGBM {
SerialTreeLearner::SerialTreeLearner(const TreeConfig& tree_config) { SerialTreeLearner::SerialTreeLearner(const TreeConfig& tree_config)
// initialize with nullptr :tree_config_(tree_config){
num_leaves_ = tree_config.num_leaves;
min_num_data_one_leaf_ = static_cast<data_size_t>(tree_config.min_data_in_leaf);
min_sum_hessian_one_leaf_ = static_cast<double>(tree_config.min_sum_hessian_in_leaf);
lambda_l1_ = tree_config.lambda_l1;
lambda_l2_ = tree_config.lambda_l2;
min_gain_to_split_ = tree_config.min_gain_to_split;
feature_fraction_ = tree_config.feature_fraction;
random_ = Random(tree_config.feature_fraction_seed); random_ = Random(tree_config.feature_fraction_seed);
histogram_pool_size_ = tree_config.histogram_pool_size;
max_depth_ = tree_config.max_depth;
} }
SerialTreeLearner::~SerialTreeLearner() { SerialTreeLearner::~SerialTreeLearner() {
...@@ -31,36 +22,32 @@ void SerialTreeLearner::Init(const Dataset* train_data) { ...@@ -31,36 +22,32 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
num_features_ = train_data_->num_features(); num_features_ = train_data_->num_features();
int max_cache_size = 0; int max_cache_size = 0;
// Get the max size of pool // Get the max size of pool
if (histogram_pool_size_ < 0) { if (tree_config_.histogram_pool_size < 0) {
max_cache_size = num_leaves_; max_cache_size = tree_config_.num_leaves;
} else { } else {
size_t total_histogram_size = 0; size_t total_histogram_size = 0;
for (int i = 0; i < train_data_->num_features(); ++i) { for (int i = 0; i < train_data_->num_features(); ++i) {
total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin(); total_histogram_size += sizeof(HistogramBinEntry) * train_data_->FeatureAt(i)->num_bin();
} }
max_cache_size = static_cast<int>(histogram_pool_size_ * 1024 * 1024 / total_histogram_size); max_cache_size = static_cast<int>(tree_config_.histogram_pool_size * 1024 * 1024 / total_histogram_size);
} }
// at least need 2 leaves // at least need 2 leaves
max_cache_size = std::max(2, max_cache_size); max_cache_size = std::max(2, max_cache_size);
max_cache_size = std::min(max_cache_size, num_leaves_); max_cache_size = std::min(max_cache_size, tree_config_.num_leaves);
histogram_pool_.ResetSize(max_cache_size, num_leaves_); histogram_pool_.ResetSize(max_cache_size, tree_config_.num_leaves);
auto histogram_create_function = [this]() { auto histogram_create_function = [this]() {
auto tmp_histogram_array = std::unique_ptr<FeatureHistogram[]>(new FeatureHistogram[train_data_->num_features()]); auto tmp_histogram_array = std::unique_ptr<FeatureHistogram[]>(new FeatureHistogram[train_data_->num_features()]);
for (int j = 0; j < train_data_->num_features(); ++j) { for (int j = 0; j < train_data_->num_features(); ++j) {
tmp_histogram_array[j].Init(train_data_->FeatureAt(j), tmp_histogram_array[j].Init(train_data_->FeatureAt(j),
j, min_num_data_one_leaf_, j, &tree_config_);
min_sum_hessian_one_leaf_,
lambda_l1_,
lambda_l2_,
min_gain_to_split_);
} }
return tmp_histogram_array.release(); return tmp_histogram_array.release();
}; };
histogram_pool_.Fill(histogram_create_function); histogram_pool_.Fill(histogram_create_function);
// push split information for all leaves // push split information for all leaves
best_split_per_leaf_.resize(num_leaves_); best_split_per_leaf_.resize(tree_config_.num_leaves);
// initialize ordered_bins_ with nullptr // initialize ordered_bins_ with nullptr
ordered_bins_.resize(num_features_); ordered_bins_.resize(num_features_);
...@@ -82,7 +69,7 @@ void SerialTreeLearner::Init(const Dataset* train_data) { ...@@ -82,7 +69,7 @@ void SerialTreeLearner::Init(const Dataset* train_data) {
larger_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data())); larger_leaf_splits_.reset(new LeafSplits(train_data_->num_features(), train_data_->num_data()));
// initialize data partition // initialize data partition
data_partition_.reset(new DataPartition(num_data_, num_leaves_)); data_partition_.reset(new DataPartition(num_data_, tree_config_.num_leaves));
is_feature_used_.resize(num_features_); is_feature_used_.resize(num_features_);
...@@ -102,14 +89,14 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians ...@@ -102,14 +89,14 @@ Tree* SerialTreeLearner::Train(const score_t* gradients, const score_t *hessians
hessians_ = hessians; hessians_ = hessians;
// some initial works before training // some initial works before training
BeforeTrain(); BeforeTrain();
auto tree = std::unique_ptr<Tree>(new Tree(num_leaves_)); auto tree = std::unique_ptr<Tree>(new Tree(tree_config_.num_leaves));
// save pointer to last trained tree // save pointer to last trained tree
last_trained_tree_ = tree.get(); last_trained_tree_ = tree.get();
// root leaf // root leaf
int left_leaf = 0; int left_leaf = 0;
// only root leaf can be splitted on first time // only root leaf can be splitted on first time
int right_leaf = -1; int right_leaf = -1;
for (int split = 0; split < num_leaves_ - 1; split++) { for (int split = 0; split < tree_config_.num_leaves - 1; split++) {
// some initial works before finding best split // some initial works before finding best split
if (BeforeFindBestSplit(left_leaf, right_leaf)) { if (BeforeFindBestSplit(left_leaf, right_leaf)) {
// find best threshold for every feature // find best threshold for every feature
...@@ -141,7 +128,7 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -141,7 +128,7 @@ void SerialTreeLearner::BeforeTrain() {
is_feature_used_[i] = false; is_feature_used_[i] = false;
} }
// Get used feature at current tree // Get used feature at current tree
int used_feature_cnt = static_cast<int>(num_features_*feature_fraction_); int used_feature_cnt = static_cast<int>(num_features_*tree_config_.feature_fraction);
auto used_feature_indices = random_.Sample(num_features_, used_feature_cnt); auto used_feature_indices = random_.Sample(num_features_, used_feature_cnt);
for (auto idx : used_feature_indices) { for (auto idx : used_feature_indices) {
is_feature_used_[idx] = true; is_feature_used_[idx] = true;
...@@ -151,7 +138,7 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -151,7 +138,7 @@ void SerialTreeLearner::BeforeTrain() {
data_partition_->Init(); data_partition_->Init();
// reset the splits for leaves // reset the splits for leaves
for (int i = 0; i < num_leaves_; ++i) { for (int i = 0; i < tree_config_.num_leaves; ++i) {
best_split_per_leaf_[i].Reset(); best_split_per_leaf_[i].Reset();
} }
...@@ -190,7 +177,7 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -190,7 +177,7 @@ void SerialTreeLearner::BeforeTrain() {
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < num_features_; ++i) { for (int i = 0; i < num_features_; ++i) {
if (ordered_bins_[i] != nullptr) { if (ordered_bins_[i] != nullptr) {
ordered_bins_[i]->Init(nullptr, num_leaves_); ordered_bins_[i]->Init(nullptr, tree_config_.num_leaves);
} }
} }
} else { } else {
...@@ -209,7 +196,7 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -209,7 +196,7 @@ void SerialTreeLearner::BeforeTrain() {
#pragma omp parallel for schedule(guided) #pragma omp parallel for schedule(guided)
for (int i = 0; i < num_features_; ++i) { for (int i = 0; i < num_features_; ++i) {
if (ordered_bins_[i] != nullptr) { if (ordered_bins_[i] != nullptr) {
ordered_bins_[i]->Init(is_data_in_leaf_.data(), num_leaves_); ordered_bins_[i]->Init(is_data_in_leaf_.data(), tree_config_.num_leaves);
} }
} }
} }
...@@ -218,9 +205,9 @@ void SerialTreeLearner::BeforeTrain() { ...@@ -218,9 +205,9 @@ void SerialTreeLearner::BeforeTrain() {
bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) { bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
// check depth of current leaf // check depth of current leaf
if (max_depth_ > 0) { if (tree_config_.max_depth > 0) {
// only need to check left leaf, since right leaf is in same level of left leaf // only need to check left leaf, since right leaf is in same level of left leaf
if (last_trained_tree_->leaf_depth(left_leaf) >= max_depth_) { if (last_trained_tree_->leaf_depth(left_leaf) >= tree_config_.max_depth) {
best_split_per_leaf_[left_leaf].gain = kMinScore; best_split_per_leaf_[left_leaf].gain = kMinScore;
if (right_leaf >= 0) { if (right_leaf >= 0) {
best_split_per_leaf_[right_leaf].gain = kMinScore; best_split_per_leaf_[right_leaf].gain = kMinScore;
...@@ -231,8 +218,8 @@ bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) { ...@@ -231,8 +218,8 @@ bool SerialTreeLearner::BeforeFindBestSplit(int left_leaf, int right_leaf) {
data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf); data_size_t num_data_in_left_child = GetGlobalDataCountInLeaf(left_leaf);
data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf); data_size_t num_data_in_right_child = GetGlobalDataCountInLeaf(right_leaf);
// no enough data to continue // no enough data to continue
if (num_data_in_right_child < static_cast<data_size_t>(min_num_data_one_leaf_ * 2) if (num_data_in_right_child < static_cast<data_size_t>(tree_config_.min_data_in_leaf * 2)
&& num_data_in_left_child < static_cast<data_size_t>(min_num_data_one_leaf_ * 2)) { && num_data_in_left_child < static_cast<data_size_t>(tree_config_.min_data_in_leaf * 2)) {
best_split_per_leaf_[left_leaf].gain = kMinScore; best_split_per_leaf_[left_leaf].gain = kMinScore;
if (right_leaf >= 0) { if (right_leaf >= 0) {
best_split_per_leaf_[right_leaf].gain = kMinScore; best_split_per_leaf_[right_leaf].gain = kMinScore;
...@@ -393,11 +380,15 @@ void SerialTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* ri ...@@ -393,11 +380,15 @@ void SerialTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* ri
// left = parent // left = parent
*left_leaf = best_Leaf; *left_leaf = best_Leaf;
// split tree, will return right leaf // split tree, will return right leaf
*right_leaf = tree->Split(best_Leaf, best_split_info.feature, best_split_info.threshold, *right_leaf = tree->Split(best_Leaf, best_split_info.feature,
train_data_->FeatureAt(best_split_info.feature)->bin_type(),
best_split_info.threshold,
train_data_->FeatureAt(best_split_info.feature)->feature_index(), train_data_->FeatureAt(best_split_info.feature)->feature_index(),
train_data_->FeatureAt(best_split_info.feature)->BinToValue(best_split_info.threshold), train_data_->FeatureAt(best_split_info.feature)->BinToValue(best_split_info.threshold),
static_cast<double>(best_split_info.left_output), static_cast<double>(best_split_info.left_output),
static_cast<double>(best_split_info.right_output), static_cast<double>(best_split_info.right_output),
static_cast<data_size_t>(best_split_info.left_count),
static_cast<data_size_t>(best_split_info.right_count),
static_cast<double>(best_split_info.gain)); static_cast<double>(best_split_info.gain));
// split data partition // split data partition
......
...@@ -109,20 +109,6 @@ protected: ...@@ -109,20 +109,6 @@ protected:
const score_t* gradients_; const score_t* gradients_;
/*! \brief hessians of current iteration */ /*! \brief hessians of current iteration */
const score_t* hessians_; const score_t* hessians_;
/*! \brief number of total leaves */
int num_leaves_;
/*! \brief minimal data on one leaf */
data_size_t min_num_data_one_leaf_;
/*! \brief minimal sum hessian on one leaf */
double min_sum_hessian_one_leaf_;
/*! \brief lambda of the L1 weights regularization */
double lambda_l1_;
/*! \brief lambda of the L2 weights regularization */
double lambda_l2_;
/*! \brief minimal gain (loss reduction) to split */
double min_gain_to_split_;
/*! \brief sub-feature fraction rate */
double feature_fraction_;
/*! \brief training data partition on leaves */ /*! \brief training data partition on leaves */
std::unique_ptr<DataPartition> data_partition_; std::unique_ptr<DataPartition> data_partition_;
/*! \brief used for generate used features */ /*! \brief used for generate used features */
...@@ -158,19 +144,16 @@ protected: ...@@ -158,19 +144,16 @@ protected:
const score_t* ptr_to_ordered_gradients_larger_leaf_; const score_t* ptr_to_ordered_gradients_larger_leaf_;
/*! \brief Pointer to ordered_hessians_, use this to avoid copy at BeforeTrain*/ /*! \brief Pointer to ordered_hessians_, use this to avoid copy at BeforeTrain*/
const score_t* ptr_to_ordered_hessians_larger_leaf_; const score_t* ptr_to_ordered_hessians_larger_leaf_;
/*! \brief Store ordered bin */ /*! \brief Store ordered bin */
std::vector<std::unique_ptr<OrderedBin>> ordered_bins_; std::vector<std::unique_ptr<OrderedBin>> ordered_bins_;
/*! \brief True if has ordered bin */ /*! \brief True if has ordered bin */
bool has_ordered_bin_ = false; bool has_ordered_bin_ = false;
/*! \brief is_data_in_leaf_[i] != 0 means i-th data is marked */ /*! \brief is_data_in_leaf_[i] != 0 means i-th data is marked */
std::vector<char> is_data_in_leaf_; std::vector<char> is_data_in_leaf_;
/*! \brief max cache size(unit:GB) for historical histogram. < 0 means not limit */
double histogram_pool_size_;
/*! \brief used to cache historical histogram to speed up*/ /*! \brief used to cache historical histogram to speed up*/
HistogramPool histogram_pool_; HistogramPool histogram_pool_;
/*! \brief max depth of tree model */ /*! \brief config of tree learner*/
int max_depth_; const TreeConfig& tree_config_;
}; };
......
...@@ -7,6 +7,11 @@ X, Y = datasets.make_classification(n_samples=100000, n_features=100) ...@@ -7,6 +7,11 @@ X, Y = datasets.make_classification(n_samples=100000, n_features=100)
x_train, x_test, y_train, y_test = model_selection.train_test_split(X, Y, test_size=0.1) x_train, x_test, y_train, y_test = model_selection.train_test_split(X, Y, test_size=0.1)
train_data = lgb.Dataset(x_train, max_bin=255, label=y_train) train_data = lgb.Dataset(x_train, max_bin=255, label=y_train)
num_features = train_data.num_feature()
names = ["name_%d" %(i) for i in range(num_features)]
train_data.set_feature_name(names)
valid_data = train_data.create_valid(x_test, label=y_test) valid_data = train_data.create_valid(x_test, label=y_test)
config={"objective":"binary","metric":"auc", "min_data":1, "num_leaves":15} config={"objective":"binary","metric":"auc", "min_data":1, "num_leaves":15}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment