Unverified Commit e005cdb0 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

Monotone Constraint (#1314)

parent 45adbf89
...@@ -97,7 +97,7 @@ if(USE_HDFS) ...@@ -97,7 +97,7 @@ if(USE_HDFS)
endif(USE_HDFS) endif(USE_HDFS)
if(UNIX OR MINGW OR CYGWIN) if(UNIX OR MINGW OR CYGWIN)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -pthread -O3 -Wextra -Wall -Wno-ignored-attributes -Wno-unknown-pragmas") SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -pthread -O3 -Wextra -Wall -Wno-ignored-attributes -Wno-unknown-pragmas -Wno-return-type")
endif() endif()
if(WIN32 AND MINGW) if(WIN32 AND MINGW)
......
...@@ -313,6 +313,14 @@ Learning Control Parameters ...@@ -313,6 +313,14 @@ Learning Control Parameters
- set this to larger value for more accurate result, but it will slow down the training speed - set this to larger value for more accurate result, but it will slow down the training speed
- ``monotone_constraints``, default=``None``, type=multi-int, alias=\ ``mc``
- used for constraints on monotone features
- ``1`` means increasing, ``-1`` means decreasing, ``0`` means no constraint
- need to specify all features in order. For example, ``mc=-1,0,1`` means a decreasing constraint on the 1st feature, no constraint on the 2nd feature and an increasing constraint on the 3rd feature.
IO Parameters IO Parameters
------------- -------------
......
...@@ -126,6 +126,7 @@ public: ...@@ -126,6 +126,7 @@ public:
double max_conflict_rate = 0.0f; double max_conflict_rate = 0.0f;
bool enable_bundle = true; bool enable_bundle = true;
bool has_header = false; bool has_header = false;
std::vector<int8_t> monotone_constraints;
/*! \brief Index or column name of label, default is the first column /*! \brief Index or column name of label, default is the first column
* And add an prefix "name:" while using column name */ * And add an prefix "name:" while using column name */
std::string label_column = ""; std::string label_column = "";
...@@ -444,7 +445,8 @@ struct ParameterAlias { ...@@ -444,7 +445,8 @@ struct ParameterAlias {
{ "workers", "machines" }, { "workers", "machines" },
{ "nodes", "machines" }, { "nodes", "machines" },
{ "subsample_for_bin", "bin_construct_sample_cnt" }, { "subsample_for_bin", "bin_construct_sample_cnt" },
{ "metric_freq", "output_freq" } { "metric_freq", "output_freq" },
{ "mc", "monotone_constraints" }
}); });
const std::unordered_set<std::string> parameter_set({ const std::unordered_set<std::string> parameter_set({
"config", "config_file", "task", "device", "config", "config_file", "task", "device",
...@@ -477,7 +479,7 @@ struct ParameterAlias { ...@@ -477,7 +479,7 @@ struct ParameterAlias {
"histogram_pool_size", "is_provide_training_metric", "machine_list_filename", "machines", "histogram_pool_size", "is_provide_training_metric", "machine_list_filename", "machines",
"zero_as_missing", "init_score_file", "valid_init_score_file", "is_predict_contrib", "zero_as_missing", "init_score_file", "valid_init_score_file", "is_predict_contrib",
"max_cat_threshold", "cat_smooth", "min_data_per_group", "cat_l2", "max_cat_to_onehot", "max_cat_threshold", "cat_smooth", "min_data_per_group", "cat_l2", "max_cat_to_onehot",
"alpha", "reg_sqrt", "tweedie_variance_power" "alpha", "reg_sqrt", "tweedie_variance_power", "monotone_constraints"
}); });
std::unordered_map<std::string, std::string> tmp_map; std::unordered_map<std::string, std::string> tmp_map;
for (const auto& pair : *params) { for (const auto& pair : *params) {
......
...@@ -434,6 +434,27 @@ public: ...@@ -434,6 +434,27 @@ public:
const int sub_feature = feature2subfeature_[i]; const int sub_feature = feature2subfeature_[i];
return feature_groups_[group]->bin_mappers_[sub_feature]->num_bin(); return feature_groups_[group]->bin_mappers_[sub_feature]->num_bin();
} }
/*!
* \brief Gets the monotone-constraint type of one (inner) feature.
* \param i inner feature index
* \return 0 when no constraints were configured; otherwise the per-feature
*         type (1 = increasing, -1 = decreasing, 0 = no constraint)
*/
inline int8_t FeatureMonotone(int i) const {
  return monotone_types_.empty() ? 0 : monotone_types_[i];
}
/*!
* \brief Checks whether any feature carries a monotone constraint.
* \return true iff at least one stored constraint type is non-zero
*/
bool HasMonotone() const {
  if (monotone_types_.empty()) {
    return false;
  }
  for (auto type : monotone_types_) {
    if (type != 0) {
      return true;
    }
  }
  return false;
}
inline int FeatureGroupNumBin(int group) const { inline int FeatureGroupNumBin(int group) const {
return feature_groups_[group]->num_total_bin_; return feature_groups_[group]->num_total_bin_;
...@@ -576,6 +597,7 @@ private: ...@@ -576,6 +597,7 @@ private:
std::vector<uint64_t> group_bin_boundaries_; std::vector<uint64_t> group_bin_boundaries_;
std::vector<int> group_feature_start_; std::vector<int> group_feature_start_;
std::vector<int> group_feature_cnt_; std::vector<int> group_feature_cnt_;
std::vector<int8_t> monotone_types_;
bool is_finish_load_; bool is_finish_load_;
}; };
......
...@@ -160,6 +160,22 @@ public: ...@@ -160,6 +160,22 @@ public:
out->erase(out->begin() + k, out->end()); out->erase(out->begin() + k, out->end());
} }
inline static void Assign(std::vector<VAL_T>* array, VAL_T t, size_t n) {
array->resize(n);
for (size_t i = 0; i < array->size(); ++i) {
(*array)[i] = t;
}
}
inline static bool CheckAllZero(const std::vector<VAL_T>& array) {
for (size_t i = 0; i < array.size(); ++i) {
if (array[i] != VAL_T(0)) {
return false;
}
}
return true;
}
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -292,6 +292,9 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -292,6 +292,9 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetString(params, "convert_model", &convert_model); GetString(params, "convert_model", &convert_model);
GetString(params, "output_result", &output_result); GetString(params, "output_result", &output_result);
std::string tmp_str = ""; std::string tmp_str = "";
if (GetString(params, "monotone_constraints", &tmp_str)) {
monotone_constraints = Common::StringToArray<int8_t>(tmp_str.c_str(), ',');
}
if (GetString(params, "valid_data", &tmp_str)) { if (GetString(params, "valid_data", &tmp_str)) {
valid_data_filenames = Common::Split(tmp_str.c_str(), ','); valid_data_filenames = Common::Split(tmp_str.c_str(), ',');
} }
......
...@@ -292,6 +292,20 @@ void Dataset::Construct( ...@@ -292,6 +292,20 @@ void Dataset::Construct(
last_group = group; last_group = group;
} }
} }
if (!io_config.monotone_constraints.empty()) {
CHECK(static_cast<size_t>(num_total_features_) == io_config.monotone_constraints.size());
monotone_types_.resize(num_features_);
for (int i = 0; i < num_total_features_; ++i) {
int inner_fidx = InnerFeatureIndex(i);
if (inner_fidx >= 0) {
monotone_types_[inner_fidx] = io_config.monotone_constraints[i];
}
}
if (ArrayArgs<int8_t>::CheckAllZero(monotone_types_)) {
monotone_types_.clear();
}
}
} }
void Dataset::FinishLoad() { void Dataset::FinishLoad() {
...@@ -335,6 +349,7 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) { ...@@ -335,6 +349,7 @@ void Dataset::CopyFeatureMapperFrom(const Dataset* dataset) {
group_bin_boundaries_ = dataset->group_bin_boundaries_; group_bin_boundaries_ = dataset->group_bin_boundaries_;
group_feature_start_ = dataset->group_feature_start_; group_feature_start_ = dataset->group_feature_start_;
group_feature_cnt_ = dataset->group_feature_cnt_; group_feature_cnt_ = dataset->group_feature_cnt_;
monotone_types_ = dataset->monotone_types_;
} }
void Dataset::CreateValid(const Dataset* dataset) { void Dataset::CreateValid(const Dataset* dataset) {
...@@ -387,6 +402,7 @@ void Dataset::CreateValid(const Dataset* dataset) { ...@@ -387,6 +402,7 @@ void Dataset::CreateValid(const Dataset* dataset) {
last_group = group; last_group = group;
} }
} }
monotone_types_ = dataset->monotone_types_;
} }
void Dataset::ReSize(data_size_t num_data) { void Dataset::ReSize(data_size_t num_data) {
...@@ -539,7 +555,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) { ...@@ -539,7 +555,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
// get size of header // get size of header
size_t size_of_header = sizeof(num_data_) + sizeof(num_features_) + sizeof(num_total_features_) size_t size_of_header = sizeof(num_data_) + sizeof(num_features_) + sizeof(num_total_features_)
+ sizeof(int) * num_total_features_ + sizeof(label_idx_) + sizeof(num_groups_) + sizeof(int) * num_total_features_ + sizeof(label_idx_) + sizeof(num_groups_)
+ 3 * sizeof(int) * num_features_ + sizeof(uint64_t) * (num_groups_ + 1) + 2 * sizeof(int) * num_groups_; + 3 * sizeof(int) * num_features_ + sizeof(uint64_t) * (num_groups_ + 1) + 2 * sizeof(int) * num_groups_ + sizeof(int8_t) * num_features_;
// size of feature names // size of feature names
for (int i = 0; i < num_total_features_; ++i) { for (int i = 0; i < num_total_features_; ++i) {
size_of_header += feature_names_[i].size() + sizeof(int); size_of_header += feature_names_[i].size() + sizeof(int);
...@@ -558,7 +574,13 @@ void Dataset::SaveBinaryFile(const char* bin_filename) { ...@@ -558,7 +574,13 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
writer->Write(group_bin_boundaries_.data(), sizeof(uint64_t) * (num_groups_ + 1)); writer->Write(group_bin_boundaries_.data(), sizeof(uint64_t) * (num_groups_ + 1));
writer->Write(group_feature_start_.data(), sizeof(int) * num_groups_); writer->Write(group_feature_start_.data(), sizeof(int) * num_groups_);
writer->Write(group_feature_cnt_.data(), sizeof(int) * num_groups_); writer->Write(group_feature_cnt_.data(), sizeof(int) * num_groups_);
if (monotone_types_.empty()) {
ArrayArgs<int8_t>::Assign(&monotone_types_, 0, num_features_);
}
writer->Write(monotone_types_.data(), sizeof(int8_t) * num_features_);
if (ArrayArgs<int8_t>::CheckAllZero(monotone_types_)) {
monotone_types_.clear();
}
// write feature names // write feature names
for (int i = 0; i < num_total_features_; ++i) { for (int i = 0; i < num_total_features_; ++i) {
int str_len = static_cast<int>(feature_names_[i].size()); int str_len = static_cast<int>(feature_names_[i].size());
......
#include <LightGBM/utils/openmp_wrapper.h> #include <LightGBM/utils/openmp_wrapper.h>
#include <LightGBM/utils/log.h> #include <LightGBM/utils/log.h>
#include <LightGBM/utils/array_args.h>
#include <LightGBM/dataset_loader.h> #include <LightGBM/dataset_loader.h>
#include <LightGBM/network.h> #include <LightGBM/network.h>
...@@ -368,6 +369,17 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -368,6 +369,17 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
} }
mem_ptr += sizeof(int) * (dataset->num_groups_); mem_ptr += sizeof(int) * (dataset->num_groups_);
const int8_t* tmp_ptr_monotone_type = reinterpret_cast<const int8_t*>(mem_ptr);
dataset->monotone_types_.clear();
for (int i = 0; i < dataset->num_features_; ++i) {
dataset->monotone_types_.push_back(tmp_ptr_monotone_type[i]);
}
mem_ptr += sizeof(int8_t) * (dataset->num_features_);
if (ArrayArgs<int8_t>::CheckAllZero(dataset->monotone_types_)) {
dataset->monotone_types_.clear();
}
// get feature names // get feature names
dataset->feature_names_.clear(); dataset->feature_names_.clear();
// write feature names // write feature names
......
...@@ -187,6 +187,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const ...@@ -187,6 +187,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_gradients(),
this->smaller_leaf_splits_->sum_hessians(), this->smaller_leaf_splits_->sum_hessians(),
GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->LeafIndex()), GetGlobalDataCountInLeaf(this->smaller_leaf_splits_->LeafIndex()),
this->smaller_leaf_splits_->min_constraint(),
this->smaller_leaf_splits_->max_constraint(),
&smaller_split); &smaller_split);
smaller_split.feature = real_feature_index; smaller_split.feature = real_feature_index;
if (smaller_split > smaller_bests_per_thread[tid]) { if (smaller_split > smaller_bests_per_thread[tid]) {
...@@ -205,6 +207,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const ...@@ -205,6 +207,8 @@ void DataParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(const
this->larger_leaf_splits_->sum_gradients(), this->larger_leaf_splits_->sum_gradients(),
this->larger_leaf_splits_->sum_hessians(), this->larger_leaf_splits_->sum_hessians(),
GetGlobalDataCountInLeaf(this->larger_leaf_splits_->LeafIndex()), GetGlobalDataCountInLeaf(this->larger_leaf_splits_->LeafIndex()),
this->larger_leaf_splits_->min_constraint(),
this->larger_leaf_splits_->max_constraint(),
&larger_split); &larger_split);
larger_split.feature = real_feature_index; larger_split.feature = real_feature_index;
if (larger_split > larger_bests_per_thread[tid]) { if (larger_split > larger_bests_per_thread[tid]) {
......
...@@ -17,6 +17,7 @@ public: ...@@ -17,6 +17,7 @@ public:
MissingType missing_type; MissingType missing_type;
int8_t bias = 0; int8_t bias = 0;
uint32_t default_bin; uint32_t default_bin;
int8_t monotone_type;
/*! \brief pointer of tree config */ /*! \brief pointer of tree config */
const TreeConfig* tree_config; const TreeConfig* tree_config;
}; };
...@@ -47,10 +48,10 @@ public: ...@@ -47,10 +48,10 @@ public:
data_ = data; data_ = data;
if (bin_type == BinType::NumericalBin) { if (bin_type == BinType::NumericalBin) {
find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdNumerical, this, std::placeholders::_1 find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdNumerical, this, std::placeholders::_1
, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4); , std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6);
} else { } else {
find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdCategorical, this, std::placeholders::_1 find_best_threshold_fun_ = std::bind(&FeatureHistogram::FindBestThresholdCategorical, this, std::placeholders::_1
, std::placeholders::_2, std::placeholders::_3, std::placeholders::_4); , std::placeholders::_2, std::placeholders::_3, std::placeholders::_4, std::placeholders::_5, std::placeholders::_6);
} }
} }
...@@ -69,14 +70,14 @@ public: ...@@ -69,14 +70,14 @@ public:
} }
} }
void FindBestThreshold(double sum_gradient, double sum_hessian, data_size_t num_data, void FindBestThreshold(double sum_gradient, double sum_hessian, data_size_t num_data, double min_constraint, double max_constraint,
SplitInfo* output) { SplitInfo* output) {
output->default_left = true; output->default_left = true;
output->gain = kMinScore; output->gain = kMinScore;
find_best_threshold_fun_(sum_gradient, sum_hessian + 2 * kEpsilon, num_data, output); find_best_threshold_fun_(sum_gradient, sum_hessian + 2 * kEpsilon, num_data, min_constraint, max_constraint, output);
} }
void FindBestThresholdNumerical(double sum_gradient, double sum_hessian, data_size_t num_data, void FindBestThresholdNumerical(double sum_gradient, double sum_hessian, data_size_t num_data, double min_constraint, double max_constraint,
SplitInfo* output) { SplitInfo* output) {
is_splittable_ = false; is_splittable_ = false;
...@@ -85,23 +86,26 @@ public: ...@@ -85,23 +86,26 @@ public:
double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split; double min_gain_shift = gain_shift + meta_->tree_config->min_gain_to_split;
if (meta_->num_bin > 2 && meta_->missing_type != MissingType::None) { if (meta_->num_bin > 2 && meta_->missing_type != MissingType::None) {
if (meta_->missing_type == MissingType::Zero) { if (meta_->missing_type == MissingType::Zero) {
FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_gain_shift, output, -1, true, false); FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, true, false);
FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_gain_shift, output, 1, true, false); FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, 1, true, false);
} else { } else {
FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_gain_shift, output, -1, false, true); FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, false, true);
FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_gain_shift, output, 1, false, true); FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, 1, false, true);
} }
} else { } else {
FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_gain_shift, output, -1, false, false); FindBestThresholdSequence(sum_gradient, sum_hessian, num_data, min_constraint, max_constraint, min_gain_shift, output, -1, false, false);
// fix the direction error when only have 2 bins // fix the direction error when only have 2 bins
if (meta_->missing_type == MissingType::NaN) { if (meta_->missing_type == MissingType::NaN) {
output->default_left = false; output->default_left = false;
} }
} }
output->gain -= min_gain_shift; output->gain -= min_gain_shift;
output->monotone_type = meta_->monotone_type;
output->min_constraint = min_constraint;
output->max_constraint = max_constraint;
} }
void FindBestThresholdCategorical(double sum_gradient, double sum_hessian, data_size_t num_data, void FindBestThresholdCategorical(double sum_gradient, double sum_hessian, data_size_t num_data, double min_constraint, double max_constraint,
SplitInfo* output) { SplitInfo* output) {
output->default_left = false; output->default_left = false;
double best_gain = kMinScore; double best_gain = kMinScore;
...@@ -135,10 +139,9 @@ public: ...@@ -135,10 +139,9 @@ public:
double sum_other_gradient = sum_gradient - data_[t].sum_gradients; double sum_other_gradient = sum_gradient - data_[t].sum_gradients;
// current split gain // current split gain
double current_gain = GetLeafSplitGain(sum_other_gradient, sum_other_hessian, double current_gain = GetSplitGains(sum_other_gradient, sum_other_hessian, data_[t].sum_gradients, data_[t].sum_hessians + kEpsilon,
meta_->tree_config->lambda_l1, l2) meta_->tree_config->lambda_l1, l2,
+ GetLeafSplitGain(data_[t].sum_gradients, data_[t].sum_hessians + kEpsilon, min_constraint, max_constraint, 0);
meta_->tree_config->lambda_l1, l2);
// gain with split is worse than without split // gain with split is worse than without split
if (current_gain <= min_gain_shift) continue; if (current_gain <= min_gain_shift) continue;
...@@ -208,8 +211,9 @@ public: ...@@ -208,8 +211,9 @@ public:
cnt_cur_group = 0; cnt_cur_group = 0;
double sum_right_gradient = sum_gradient - sum_left_gradient; double sum_right_gradient = sum_gradient - sum_left_gradient;
double current_gain = GetLeafSplitGain(sum_left_gradient, sum_left_hessian, meta_->tree_config->lambda_l1, l2) double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
+ GetLeafSplitGain(sum_right_gradient, sum_right_hessian, meta_->tree_config->lambda_l1, l2); meta_->tree_config->lambda_l1, l2,
min_constraint, max_constraint, 0);
if (current_gain <= min_gain_shift) continue; if (current_gain <= min_gain_shift) continue;
is_splittable_ = true; is_splittable_ = true;
if (current_gain > best_gain) { if (current_gain > best_gain) {
...@@ -226,13 +230,13 @@ public: ...@@ -226,13 +230,13 @@ public:
if (is_splittable_) { if (is_splittable_) {
output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian, output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian,
meta_->tree_config->lambda_l1, l2); meta_->tree_config->lambda_l1, l2, min_constraint, max_constraint);
output->left_count = best_left_count; output->left_count = best_left_count;
output->left_sum_gradient = best_sum_left_gradient; output->left_sum_gradient = best_sum_left_gradient;
output->left_sum_hessian = best_sum_left_hessian - kEpsilon; output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient, output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient,
sum_hessian - best_sum_left_hessian, sum_hessian - best_sum_left_hessian,
meta_->tree_config->lambda_l1, l2); meta_->tree_config->lambda_l1, l2, min_constraint, max_constraint);
output->right_count = num_data - best_left_count; output->right_count = num_data - best_left_count;
output->right_sum_gradient = sum_gradient - best_sum_left_gradient; output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon; output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon;
...@@ -255,6 +259,9 @@ public: ...@@ -255,6 +259,9 @@ public:
} }
} }
} }
output->monotone_type = 0;
output->min_constraint = min_constraint;
output->max_constraint = max_constraint;
} }
} }
...@@ -283,17 +290,35 @@ public: ...@@ -283,17 +290,35 @@ public:
void set_is_splittable(bool val) { is_splittable_ = val; } void set_is_splittable(bool val) { is_splittable_ = val; }
/*! /*!
* \brief Calculate the split gain based on regularized sum_gradients and sum_hessians * \brief Calculate the output of a leaf based on regularized sum_gradients and sum_hessians
* \param sum_gradients * \param sum_gradients
* \param sum_hessians * \param sum_hessians
* \return split gain * \return leaf output
*/ */
static double GetLeafSplitGain(double sum_gradients, double sum_hessians, double l1, double l2) { static double CalculateSplittedLeafOutput(double sum_gradients, double sum_hessians, double l1, double l2) {
double abs_sum_gradients = std::fabs(sum_gradients); const double reg_abs_sum_gradients = std::max(0.0, std::fabs(sum_gradients) - l1);
double reg_abs_sum_gradients = std::max(0.0, abs_sum_gradients - l1); return -(Common::Sign(sum_gradients) * reg_abs_sum_gradients) / (sum_hessians + l2);
return (reg_abs_sum_gradients * reg_abs_sum_gradients) }
/ (sum_hessians + l2);
private:
/*!
* \brief Computes the gain of a split as the sum of both children's gains,
*        honoring leaf-output value bounds and an optional monotone constraint.
* \param monotone_constraint >0 requires left output <= right output,
*        <0 requires left output >= right output, 0 means unconstrained
* \return the combined split gain, or 0 when the (clamped) child outputs
*         violate the monotone constraint, which rejects the split
*/
static double GetSplitGains(double sum_left_gradients, double sum_left_hessians,
double sum_right_gradients, double sum_right_hessians,
double l1, double l2,
double min_constraint, double max_constraint, int8_t monotone_constraint) {
// child outputs are clamped into [min_constraint, max_constraint]
double left_output = CalculateSplittedLeafOutput(sum_left_gradients, sum_left_hessians, l1, l2, min_constraint, max_constraint);
double right_output = CalculateSplittedLeafOutput(sum_right_gradients, sum_right_hessians, l1, l2, min_constraint, max_constraint);
// a split whose child outputs move against the constraint direction gets
// zero gain, so it loses against the min_gain_shift comparison at call sites
if (((monotone_constraint > 0) && (left_output > right_output)) ||
((monotone_constraint < 0) && (left_output < right_output))) {
return 0;
}
// gains are evaluated at the (possibly clamped) outputs, not the unconstrained optimum
return GetLeafSplitGainGivenOutput(sum_left_gradients, sum_left_hessians, l1, l2, left_output)
+ GetLeafSplitGainGivenOutput(sum_right_gradients, sum_right_hessians, l1, l2, right_output);
}
/*!
* \brief L1 soft-thresholding: shrinks |s| by l1 toward zero, keeping the sign of s.
* \return 0 when |s| <= l1, otherwise sign(s) * (|s| - l1)
*/
static double ThresholdL1(double s, double l1) {
const double reg_s = std::max(0.0, std::fabs(s) - l1);
return Common::Sign(s) * reg_s;
}
/*! /*!
...@@ -302,15 +327,37 @@ public: ...@@ -302,15 +327,37 @@ public:
* \param sum_hessians * \param sum_hessians
* \return leaf output * \return leaf output
*/ */
static double CalculateSplittedLeafOutput(double sum_gradients, double sum_hessians, double l1, double l2) { static double CalculateSplittedLeafOutput(double sum_gradients, double sum_hessians, double l1, double l2,
const double reg_abs_sum_gradients = std::max(0.0, std::fabs(sum_gradients) - l1); double min_constraint, double max_constraint) {
return -(Common::Sign(sum_gradients) * reg_abs_sum_gradients) / (sum_hessians + l2); double ret = -ThresholdL1(sum_gradients, l1) / (sum_hessians + l2);
if (ret < min_constraint) {
ret = min_constraint;
} else if (ret > max_constraint) {
ret = max_constraint;
}
return ret;
} }
private: /*!
* \brief Calculate the split gain based on regularized sum_gradients and sum_hessians
* \param sum_gradients
* \param sum_hessians
* \return split gain
*/
/*!
* \brief Calculates the split gain contribution of one leaf from its
*        regularized gradient/hessian sums.
* \param l1 L1 regularization (soft-thresholds the gradient sum)
* \param l2 L2 regularization (added to the hessian sum)
* \return (max(0, |sum_gradients| - l1))^2 / (sum_hessians + l2)
*/
static double GetLeafSplitGain(double sum_gradients, double sum_hessians, double l1, double l2) {
  const double reg_grad = std::max(0.0, std::fabs(sum_gradients) - l1);
  return reg_grad * reg_grad / (sum_hessians + l2);
}
/*!
* \brief Evaluates a leaf's gain at a fixed (possibly constraint-clamped)
*        output instead of at the unconstrained optimum.
* \param output the leaf output value to evaluate the objective at
* \return -(2 * thresholded_gradient * output + (sum_hessians + l2) * output^2)
*/
static double GetLeafSplitGainGivenOutput(double sum_gradients, double sum_hessians, double l1, double l2, double output) {
  const double thresholded_grad = ThresholdL1(sum_gradients, l1);
  const double reg_hess = sum_hessians + l2;
  return -(2.0 * thresholded_grad * output + reg_hess * output * output);
}
void FindBestThresholdSequence(double sum_gradient, double sum_hessian, data_size_t num_data, double min_gain_shift, void FindBestThresholdSequence(double sum_gradient, double sum_hessian, data_size_t num_data, double min_constraint, double max_constraint,
SplitInfo* output, int dir, bool skip_default_bin, bool use_na_as_missing) { double min_gain_shift, SplitInfo* output, int dir, bool skip_default_bin, bool use_na_as_missing) {
const int8_t bias = meta_->bias; const int8_t bias = meta_->bias;
...@@ -351,10 +398,9 @@ private: ...@@ -351,10 +398,9 @@ private:
double sum_left_gradient = sum_gradient - sum_right_gradient; double sum_left_gradient = sum_gradient - sum_right_gradient;
// current split gain // current split gain
double current_gain = GetLeafSplitGain(sum_left_gradient, sum_left_hessian, double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2) meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2,
+ GetLeafSplitGain(sum_right_gradient, sum_right_hessian, min_constraint, max_constraint, meta_->monotone_type);
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2);
// gain with split is worse than without split // gain with split is worse than without split
if (current_gain <= min_gain_shift) continue; if (current_gain <= min_gain_shift) continue;
...@@ -412,10 +458,9 @@ private: ...@@ -412,10 +458,9 @@ private:
double sum_right_gradient = sum_gradient - sum_left_gradient; double sum_right_gradient = sum_gradient - sum_left_gradient;
// current split gain // current split gain
double current_gain = GetLeafSplitGain(sum_left_gradient, sum_left_hessian, double current_gain = GetSplitGains(sum_left_gradient, sum_left_hessian, sum_right_gradient, sum_right_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2) meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2,
+ GetLeafSplitGain(sum_right_gradient, sum_right_hessian, min_constraint, max_constraint, meta_->monotone_type);
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2);
// gain with split is worse than without split // gain with split is worse than without split
if (current_gain <= min_gain_shift) continue; if (current_gain <= min_gain_shift) continue;
...@@ -436,13 +481,13 @@ private: ...@@ -436,13 +481,13 @@ private:
// update split information // update split information
output->threshold = best_threshold; output->threshold = best_threshold;
output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian, output->left_output = CalculateSplittedLeafOutput(best_sum_left_gradient, best_sum_left_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2); meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, min_constraint, max_constraint);
output->left_count = best_left_count; output->left_count = best_left_count;
output->left_sum_gradient = best_sum_left_gradient; output->left_sum_gradient = best_sum_left_gradient;
output->left_sum_hessian = best_sum_left_hessian - kEpsilon; output->left_sum_hessian = best_sum_left_hessian - kEpsilon;
output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient, output->right_output = CalculateSplittedLeafOutput(sum_gradient - best_sum_left_gradient,
sum_hessian - best_sum_left_hessian, sum_hessian - best_sum_left_hessian,
meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2); meta_->tree_config->lambda_l1, meta_->tree_config->lambda_l2, min_constraint, max_constraint);
output->right_count = num_data - best_left_count; output->right_count = num_data - best_left_count;
output->right_sum_gradient = sum_gradient - best_sum_left_gradient; output->right_sum_gradient = sum_gradient - best_sum_left_gradient;
output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon; output->right_sum_hessian = sum_hessian - best_sum_left_hessian - kEpsilon;
...@@ -458,7 +503,7 @@ private: ...@@ -458,7 +503,7 @@ private:
/*! \brief False if this histogram cannot split */ /*! \brief False if this histogram cannot split */
bool is_splittable_ = true; bool is_splittable_ = true;
std::function<void(double, double, data_size_t, SplitInfo*)> find_best_threshold_fun_; std::function<void(double, double, data_size_t, double, double, SplitInfo*)> find_best_threshold_fun_;
}; };
class HistogramPool { class HistogramPool {
public: public:
...@@ -516,6 +561,7 @@ public: ...@@ -516,6 +561,7 @@ public:
feature_metas_[i].num_bin = train_data->FeatureNumBin(i); feature_metas_[i].num_bin = train_data->FeatureNumBin(i);
feature_metas_[i].default_bin = train_data->FeatureBinMapper(i)->GetDefaultBin(); feature_metas_[i].default_bin = train_data->FeatureBinMapper(i)->GetDefaultBin();
feature_metas_[i].missing_type = train_data->FeatureBinMapper(i)->missing_type(); feature_metas_[i].missing_type = train_data->FeatureBinMapper(i)->missing_type();
feature_metas_[i].monotone_type = train_data->FeatureMonotone(i);
if (train_data->FeatureBinMapper(i)->GetDefaultBin() == 0) { if (train_data->FeatureBinMapper(i)->GetDefaultBin() == 0) {
feature_metas_[i].bias = 1; feature_metas_[i].bias = 1;
} else { } else {
......
...@@ -1106,8 +1106,14 @@ void GPUTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right ...@@ -1106,8 +1106,14 @@ void GPUTreeLearner::Split(Tree* tree, int best_Leaf, int* left_leaf, int* right
Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf()); Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf());
} }
} else { } else {
double smaller_min = smaller_leaf_splits_->min_constraint();
double smaller_max = smaller_leaf_splits_->max_constraint();
double larger_min = larger_leaf_splits_->min_constraint();
double larger_max = larger_leaf_splits_->max_constraint();
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian); smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian);
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian); larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian);
smaller_leaf_splits_->SetValueConstraint(smaller_min, smaller_max);
larger_leaf_splits_->SetValueConstraint(larger_min, larger_max);
if ((best_split_info.left_count != larger_leaf_splits_->num_data_in_leaf()) || if ((best_split_info.left_count != larger_leaf_splits_->num_data_in_leaf()) ||
(best_split_info.right_count!= smaller_leaf_splits_->num_data_in_leaf())) { (best_split_info.right_count!= smaller_leaf_splits_->num_data_in_leaf())) {
Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf()); Log::Fatal("Bug in GPU histogram! split %d: %d, smaller_leaf: %d, larger_leaf: %d\n", best_split_info.left_count, best_split_info.right_count, smaller_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf());
......
#ifndef LIGHTGBM_TREELEARNER_LEAF_SPLITS_HPP_ #ifndef LIGHTGBM_TREELEARNER_LEAF_SPLITS_HPP_
#define LIGHTGBM_TREELEARNER_LEAF_SPLITS_HPP_ #define LIGHTGBM_TREELEARNER_LEAF_SPLITS_HPP_
#include <limits>
#include <LightGBM/meta.h> #include <LightGBM/meta.h>
#include "data_partition.hpp" #include "data_partition.hpp"
...@@ -37,8 +39,16 @@ public: ...@@ -37,8 +39,16 @@ public:
data_indices_ = data_partition->GetIndexOnLeaf(leaf, &num_data_in_leaf_); data_indices_ = data_partition->GetIndexOnLeaf(leaf, &num_data_in_leaf_);
sum_gradients_ = sum_gradients; sum_gradients_ = sum_gradients;
sum_hessians_ = sum_hessians; sum_hessians_ = sum_hessians;
min_val_ = -std::numeric_limits<double>::max();
max_val_ = std::numeric_limits<double>::max();
}
void SetValueConstraint(double min, double max) {
min_val_ = min;
max_val_ = max;
} }
/*! /*!
* \brief Init splits on current leaf, it will traverse all data to sum up the results * \brief Init splits on current leaf, it will traverse all data to sum up the results
* \param gradients * \param gradients
...@@ -57,6 +67,8 @@ public: ...@@ -57,6 +67,8 @@ public:
} }
sum_gradients_ = tmp_sum_gradients; sum_gradients_ = tmp_sum_gradients;
sum_hessians_ = tmp_sum_hessians; sum_hessians_ = tmp_sum_hessians;
min_val_ = -std::numeric_limits<double>::max();
max_val_ = std::numeric_limits<double>::max();
} }
/*! /*!
...@@ -79,6 +91,8 @@ public: ...@@ -79,6 +91,8 @@ public:
} }
sum_gradients_ = tmp_sum_gradients; sum_gradients_ = tmp_sum_gradients;
sum_hessians_ = tmp_sum_hessians; sum_hessians_ = tmp_sum_hessians;
min_val_ = -std::numeric_limits<double>::max();
max_val_ = std::numeric_limits<double>::max();
} }
...@@ -91,6 +105,8 @@ public: ...@@ -91,6 +105,8 @@ public:
leaf_index_ = 0; leaf_index_ = 0;
sum_gradients_ = sum_gradients; sum_gradients_ = sum_gradients;
sum_hessians_ = sum_hessians; sum_hessians_ = sum_hessians;
min_val_ = -std::numeric_limits<double>::max();
max_val_ = std::numeric_limits<double>::max();
} }
/*! /*!
...@@ -100,6 +116,8 @@ public: ...@@ -100,6 +116,8 @@ public:
leaf_index_ = -1; leaf_index_ = -1;
data_indices_ = nullptr; data_indices_ = nullptr;
num_data_in_leaf_ = 0; num_data_in_leaf_ = 0;
min_val_ = -std::numeric_limits<double>::max();
max_val_ = std::numeric_limits<double>::max();
} }
...@@ -115,6 +133,10 @@ public: ...@@ -115,6 +133,10 @@ public:
/*! \brief Get sum of hessians of current leaf */ /*! \brief Get sum of hessians of current leaf */
double sum_hessians() const { return sum_hessians_; } double sum_hessians() const { return sum_hessians_; }
double max_constraint() const { return max_val_; }
double min_constraint() const { return min_val_; }
/*! \brief Get indices of data of current leaf */ /*! \brief Get indices of data of current leaf */
const data_size_t* data_indices() const { return data_indices_; } const data_size_t* data_indices() const { return data_indices_; }
...@@ -132,6 +154,8 @@ private: ...@@ -132,6 +154,8 @@ private:
double sum_hessians_; double sum_hessians_;
/*! \brief indices of data of current leaf */ /*! \brief indices of data of current leaf */
const data_size_t* data_indices_; const data_size_t* data_indices_;
double min_val_;
double max_val_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -480,6 +480,8 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>& ...@@ -480,6 +480,8 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>&
smaller_leaf_splits_->sum_gradients(), smaller_leaf_splits_->sum_gradients(),
smaller_leaf_splits_->sum_hessians(), smaller_leaf_splits_->sum_hessians(),
smaller_leaf_splits_->num_data_in_leaf(), smaller_leaf_splits_->num_data_in_leaf(),
smaller_leaf_splits_->min_constraint(),
smaller_leaf_splits_->max_constraint(),
&smaller_split); &smaller_split);
smaller_split.feature = real_fidx; smaller_split.feature = real_fidx;
if (smaller_split > smaller_best[tid]) { if (smaller_split > smaller_best[tid]) {
...@@ -501,6 +503,8 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>& ...@@ -501,6 +503,8 @@ void SerialTreeLearner::FindBestSplitsFromHistograms(const std::vector<int8_t>&
larger_leaf_splits_->sum_gradients(), larger_leaf_splits_->sum_gradients(),
larger_leaf_splits_->sum_hessians(), larger_leaf_splits_->sum_hessians(),
larger_leaf_splits_->num_data_in_leaf(), larger_leaf_splits_->num_data_in_leaf(),
larger_leaf_splits_->min_constraint(),
larger_leaf_splits_->max_constraint(),
&larger_split); &larger_split);
larger_split.feature = real_fidx; larger_split.feature = real_fidx;
if (larger_split > larger_best[tid]) { if (larger_split > larger_best[tid]) {
...@@ -530,7 +534,8 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri ...@@ -530,7 +534,8 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
const int inner_feature_index = train_data_->InnerFeatureIndex(best_split_info.feature); const int inner_feature_index = train_data_->InnerFeatureIndex(best_split_info.feature);
// left = parent // left = parent
*left_leaf = best_leaf; *left_leaf = best_leaf;
if (train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin) { bool is_numerical_split = train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin;
if (is_numerical_split) {
auto threshold_double = train_data_->RealThreshold(inner_feature_index, best_split_info.threshold); auto threshold_double = train_data_->RealThreshold(inner_feature_index, best_split_info.threshold);
// split tree, will return right leaf // split tree, will return right leaf
*right_leaf = tree->Split(best_leaf, *right_leaf = tree->Split(best_leaf,
...@@ -574,18 +579,29 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri ...@@ -574,18 +579,29 @@ void SerialTreeLearner::Split(Tree* tree, int best_leaf, int* left_leaf, int* ri
#ifdef DEBUG #ifdef DEBUG
CHECK(best_split_info.left_count == data_partition_->leaf_count(best_leaf)); CHECK(best_split_info.left_count == data_partition_->leaf_count(best_leaf));
#endif #endif
auto p_left = smaller_leaf_splits_.get();
auto p_right = larger_leaf_splits_.get();
// init the leaves that used on next iteration // init the leaves that used on next iteration
if (best_split_info.left_count < best_split_info.right_count) { if (best_split_info.left_count < best_split_info.right_count) {
smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(), smaller_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian);
best_split_info.left_sum_gradient, larger_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian);
best_split_info.left_sum_hessian);
larger_leaf_splits_->Init(*right_leaf, data_partition_.get(),
best_split_info.right_sum_gradient,
best_split_info.right_sum_hessian);
} else { } else {
smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian); smaller_leaf_splits_->Init(*right_leaf, data_partition_.get(), best_split_info.right_sum_gradient, best_split_info.right_sum_hessian);
larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian); larger_leaf_splits_->Init(*left_leaf, data_partition_.get(), best_split_info.left_sum_gradient, best_split_info.left_sum_hessian);
p_right = smaller_leaf_splits_.get();
p_left = larger_leaf_splits_.get();
}
p_left->SetValueConstraint(best_split_info.min_constraint, best_split_info.max_constraint);
p_right->SetValueConstraint(best_split_info.min_constraint, best_split_info.max_constraint);
if (is_numerical_split) {
double mid = (best_split_info.left_output + best_split_info.right_output) / 2.0f;
if (best_split_info.monotone_type < 0) {
p_left->SetValueConstraint(mid, best_split_info.max_constraint);
p_right->SetValueConstraint(best_split_info.min_constraint, mid);
} else if (best_split_info.monotone_type > 0) {
p_left->SetValueConstraint(best_split_info.min_constraint, mid);
p_right->SetValueConstraint(mid, best_split_info.max_constraint);
}
} }
} }
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
#include <cmath> #include <cmath>
#include <cstdint> #include <cstdint>
#include <cstring> #include <cstring>
#include <limits>
#include <functional> #include <functional>
...@@ -42,9 +43,11 @@ public: ...@@ -42,9 +43,11 @@ public:
std::vector<uint32_t> cat_threshold; std::vector<uint32_t> cat_threshold;
/*! \brief True if default split is left */ /*! \brief True if default split is left */
bool default_left = true; bool default_left = true;
int8_t monotone_type = 0;
double min_constraint = -std::numeric_limits<double>::max();
double max_constraint = std::numeric_limits<double>::max();
inline static int Size(int max_cat_threshold) { inline static int Size(int max_cat_threshold) {
return 2 * sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 7 + sizeof(data_size_t) * 2 + max_cat_threshold * sizeof(uint32_t); return 2 * sizeof(int) + sizeof(uint32_t) + sizeof(bool) + sizeof(double) * 9 + sizeof(data_size_t) * 2 + max_cat_threshold * sizeof(uint32_t) + sizeof(int8_t);
} }
inline void CopyTo(char* buffer) const { inline void CopyTo(char* buffer) const {
...@@ -72,6 +75,12 @@ public: ...@@ -72,6 +75,12 @@ public:
buffer += sizeof(right_sum_hessian); buffer += sizeof(right_sum_hessian);
std::memcpy(buffer, &default_left, sizeof(default_left)); std::memcpy(buffer, &default_left, sizeof(default_left));
buffer += sizeof(default_left); buffer += sizeof(default_left);
std::memcpy(buffer, &monotone_type, sizeof(monotone_type));
buffer += sizeof(monotone_type);
std::memcpy(buffer, &min_constraint, sizeof(min_constraint));
buffer += sizeof(min_constraint);
std::memcpy(buffer, &max_constraint, sizeof(max_constraint));
buffer += sizeof(max_constraint);
std::memcpy(buffer, &num_cat_threshold, sizeof(num_cat_threshold)); std::memcpy(buffer, &num_cat_threshold, sizeof(num_cat_threshold));
buffer += sizeof(num_cat_threshold); buffer += sizeof(num_cat_threshold);
std::memcpy(buffer, cat_threshold.data(), sizeof(uint32_t) * num_cat_threshold); std::memcpy(buffer, cat_threshold.data(), sizeof(uint32_t) * num_cat_threshold);
...@@ -102,6 +111,12 @@ public: ...@@ -102,6 +111,12 @@ public:
buffer += sizeof(right_sum_hessian); buffer += sizeof(right_sum_hessian);
std::memcpy(&default_left, buffer, sizeof(default_left)); std::memcpy(&default_left, buffer, sizeof(default_left));
buffer += sizeof(default_left); buffer += sizeof(default_left);
std::memcpy(&monotone_type, buffer, sizeof(monotone_type));
buffer += sizeof(monotone_type);
std::memcpy(&min_constraint, buffer, sizeof(min_constraint));
buffer += sizeof(min_constraint);
std::memcpy(&max_constraint, buffer, sizeof(max_constraint));
buffer += sizeof(max_constraint);
std::memcpy(&num_cat_threshold, buffer, sizeof(num_cat_threshold)); std::memcpy(&num_cat_threshold, buffer, sizeof(num_cat_threshold));
buffer += sizeof(num_cat_threshold); buffer += sizeof(num_cat_threshold);
cat_threshold.resize(num_cat_threshold); cat_threshold.resize(num_cat_threshold);
......
...@@ -69,6 +69,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b ...@@ -69,6 +69,7 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Init(const Dataset* train_data, b
feature_metas_[i].num_bin = train_data->FeatureNumBin(i); feature_metas_[i].num_bin = train_data->FeatureNumBin(i);
feature_metas_[i].default_bin = train_data->FeatureBinMapper(i)->GetDefaultBin(); feature_metas_[i].default_bin = train_data->FeatureBinMapper(i)->GetDefaultBin();
feature_metas_[i].missing_type = train_data->FeatureBinMapper(i)->missing_type(); feature_metas_[i].missing_type = train_data->FeatureBinMapper(i)->missing_type();
feature_metas_[i].monotone_type = train_data->FeatureMonotone(i);
if (train_data->FeatureBinMapper(i)->GetDefaultBin() == 0) { if (train_data->FeatureBinMapper(i)->GetDefaultBin() == 0) {
feature_metas_[i].bias = 1; feature_metas_[i].bias = 1;
} else { } else {
...@@ -290,6 +291,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() { ...@@ -290,6 +291,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
this->smaller_leaf_splits_->sum_gradients(), this->smaller_leaf_splits_->sum_gradients(),
this->smaller_leaf_splits_->sum_hessians(), this->smaller_leaf_splits_->sum_hessians(),
this->smaller_leaf_splits_->num_data_in_leaf(), this->smaller_leaf_splits_->num_data_in_leaf(),
this->smaller_leaf_splits_->min_constraint(),
this->smaller_leaf_splits_->max_constraint(),
&smaller_bestsplit_per_features[feature_index]); &smaller_bestsplit_per_features[feature_index]);
smaller_bestsplit_per_features[feature_index].feature = real_feature_index; smaller_bestsplit_per_features[feature_index].feature = real_feature_index;
// only has root leaf // only has root leaf
...@@ -307,6 +310,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() { ...@@ -307,6 +310,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplits() {
this->larger_leaf_splits_->sum_gradients(), this->larger_leaf_splits_->sum_gradients(),
this->larger_leaf_splits_->sum_hessians(), this->larger_leaf_splits_->sum_hessians(),
this->larger_leaf_splits_->num_data_in_leaf(), this->larger_leaf_splits_->num_data_in_leaf(),
this->larger_leaf_splits_->min_constraint(),
this->larger_leaf_splits_->max_constraint(),
&larger_bestsplit_per_features[feature_index]); &larger_bestsplit_per_features[feature_index]);
larger_bestsplit_per_features[feature_index].feature = real_feature_index; larger_bestsplit_per_features[feature_index].feature = real_feature_index;
OMP_LOOP_EX_END(); OMP_LOOP_EX_END();
...@@ -391,6 +396,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons ...@@ -391,6 +396,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
smaller_leaf_splits_global_->sum_gradients(), smaller_leaf_splits_global_->sum_gradients(),
smaller_leaf_splits_global_->sum_hessians(), smaller_leaf_splits_global_->sum_hessians(),
GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->LeafIndex()), GetGlobalDataCountInLeaf(smaller_leaf_splits_global_->LeafIndex()),
smaller_leaf_splits_global_->min_constraint(),
smaller_leaf_splits_global_->max_constraint(),
&smaller_split); &smaller_split);
smaller_split.feature = real_feature_index; smaller_split.feature = real_feature_index;
if (smaller_split > smaller_bests_per_thread[tid]) { if (smaller_split > smaller_bests_per_thread[tid]) {
...@@ -413,6 +420,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons ...@@ -413,6 +420,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::FindBestSplitsFromHistograms(cons
larger_leaf_splits_global_->sum_gradients(), larger_leaf_splits_global_->sum_gradients(),
larger_leaf_splits_global_->sum_hessians(), larger_leaf_splits_global_->sum_hessians(),
GetGlobalDataCountInLeaf(larger_leaf_splits_global_->LeafIndex()), GetGlobalDataCountInLeaf(larger_leaf_splits_global_->LeafIndex()),
larger_leaf_splits_global_->min_constraint(),
larger_leaf_splits_global_->max_constraint(),
&larger_split); &larger_split);
larger_split.feature = real_feature_index; larger_split.feature = real_feature_index;
if (larger_split > larger_best_per_thread[tid]) { if (larger_split > larger_best_per_thread[tid]) {
...@@ -457,6 +466,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, ...@@ -457,6 +466,8 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf,
// set the global number of data for leaves // set the global number of data for leaves
global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count; global_data_count_in_leaf_[*left_leaf] = best_split_info.left_count;
global_data_count_in_leaf_[*right_leaf] = best_split_info.right_count; global_data_count_in_leaf_[*right_leaf] = best_split_info.right_count;
auto p_left = smaller_leaf_splits_global_.get();
auto p_right = larger_leaf_splits_global_.get();
// init the global sumup info // init the global sumup info
if (best_split_info.left_count < best_split_info.right_count) { if (best_split_info.left_count < best_split_info.right_count) {
smaller_leaf_splits_global_->Init(*left_leaf, this->data_partition_.get(), smaller_leaf_splits_global_->Init(*left_leaf, this->data_partition_.get(),
...@@ -472,6 +483,22 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf, ...@@ -472,6 +483,22 @@ void VotingParallelTreeLearner<TREELEARNER_T>::Split(Tree* tree, int best_Leaf,
larger_leaf_splits_global_->Init(*left_leaf, this->data_partition_.get(), larger_leaf_splits_global_->Init(*left_leaf, this->data_partition_.get(),
best_split_info.left_sum_gradient, best_split_info.left_sum_gradient,
best_split_info.left_sum_hessian); best_split_info.left_sum_hessian);
p_left = larger_leaf_splits_global_.get();
p_right = smaller_leaf_splits_global_.get();
}
const int inner_feature_index = this->train_data_->InnerFeatureIndex(best_split_info.feature);
bool is_numerical_split = this->train_data_->FeatureBinMapper(inner_feature_index)->bin_type() == BinType::NumericalBin;
p_left->SetValueConstraint(best_split_info.min_constraint, best_split_info.max_constraint);
p_right->SetValueConstraint(best_split_info.min_constraint, best_split_info.max_constraint);
if (is_numerical_split) {
double mid = (best_split_info.left_output + best_split_info.right_output) / 2.0f;
if (best_split_info.monotone_type < 0) {
p_left->SetValueConstraint(mid, best_split_info.max_constraint);
p_right->SetValueConstraint(best_split_info.min_constraint, mid);
} else if (best_split_info.monotone_type > 0) {
p_left->SetValueConstraint(best_split_info.min_constraint, mid);
p_right->SetValueConstraint(mid, best_split_info.max_constraint);
}
} }
} }
......
...@@ -599,3 +599,45 @@ class TestEngine(unittest.TestCase): ...@@ -599,3 +599,45 @@ class TestEngine(unittest.TestCase):
assert np.all(sliced_csr == features) assert np.all(sliced_csr == features)
sliced_pred = train_and_get_predictions(sliced_csr, sliced_labels) sliced_pred = train_and_get_predictions(sliced_csr, sliced_labels)
np.testing.assert_almost_equal(origin_pred, sliced_pred) np.testing.assert_almost_equal(origin_pred, sliced_pred)
def test_monotone_constraint(self):
    """Verify that ``monotone_constraints`` enforces monotonic predictions.

    Trains on a 2-feature dataset where the target increases with
    feature 0 and decreases with feature 1, constrains the model
    accordingly (``monotone_constraints='1,-1'``), then checks that the
    fitted model's predictions are monotone along each feature while the
    other one is held fixed.
    """
    def is_increasing(y):
        # True when no adjacent pair of predictions decreases.
        return np.count_nonzero(np.diff(y) < 0.0) == 0

    def is_decreasing(y):
        # True when no adjacent pair of predictions increases.
        return np.count_nonzero(np.diff(y) > 0.0) == 0

    def is_correctly_constrained(learner):
        # Sweep one feature over [0, 1] with the other held at each of a
        # grid of fixed values; every sweep over feature 0 must be
        # non-decreasing and every sweep over feature 1 non-increasing.
        n = 200
        variable_x = np.linspace(0, 1, n).reshape((n, 1))
        fixed_xs_values = np.linspace(0, 1, n)
        for i in range(n):
            fixed_x = fixed_xs_values[i] * np.ones((n, 1))
            monotonically_increasing_x = np.column_stack((variable_x, fixed_x))
            monotonically_increasing_y = learner.predict(monotonically_increasing_x)
            monotonically_decreasing_x = np.column_stack((fixed_x, variable_x))
            monotonically_decreasing_y = learner.predict(monotonically_decreasing_x)
            if not (is_increasing(monotonically_increasing_y) and
                    is_decreasing(monotonically_decreasing_y)):
                return False
        return True

    # Seed the RNG so the generated dataset (and hence the test outcome)
    # is reproducible instead of flaky.
    np.random.seed(42)
    number_of_dpoints = 3000
    x1_positively_correlated_with_y = np.random.random(size=number_of_dpoints)
    x2_negatively_correlated_with_y = np.random.random(size=number_of_dpoints)
    x = np.column_stack((x1_positively_correlated_with_y,
                         x2_negatively_correlated_with_y))
    zs = np.random.normal(loc=0.0, scale=0.01, size=number_of_dpoints)
    # Monotone linear trend plus a non-monotone sinusoidal wiggle and
    # noise: an unconstrained model would fit the wiggles, so the check
    # below only passes if the constraint is actually applied.
    y = (
        5 * x1_positively_correlated_with_y +
        np.sin(10 * np.pi * x1_positively_correlated_with_y) -
        5 * x2_negatively_correlated_with_y -
        np.cos(10 * np.pi * x2_negatively_correlated_with_y) +
        zs
    )
    trainset = lgb.Dataset(x, label=y)
    params = {
        'min_data': 20,
        'num_leaves': 20,
        'monotone_constraints': '1,-1'
    }
    constrained_model = lgb.train(params, trainset)
    # unittest assertion gives a readable failure message, unlike bare assert.
    self.assertTrue(is_correctly_constrained(constrained_model))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment