"src/vscode:/vscode.git/clone" did not exist on "5bee6489ac4293328c826e72e7de339206c456da"
Unverified Commit f8f6c513 authored by Guolin Ke's avatar Guolin Ke Committed by GitHub
Browse files

disable monotone constraint in objective functions with renew_tree_output (#3380)

* Update gbdt.cpp

* Update gbdt.cpp

* Apply suggestions from code review
parent f30dbe87
...@@ -79,6 +79,9 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective ...@@ -79,6 +79,9 @@ void GBDT::Init(const Config* config, const Dataset* train_data, const Objective
num_tree_per_iteration_ = num_class_; num_tree_per_iteration_ = num_class_;
if (objective_function_ != nullptr) { if (objective_function_ != nullptr) {
num_tree_per_iteration_ = objective_function_->NumModelPerIteration(); num_tree_per_iteration_ = objective_function_->NumModelPerIteration();
if (objective_function_->IsRenewTreeOutput() && !config->monotone_constraints.empty()) {
Log::Fatal("Cannot use ``monotone_constraints`` in %s objective, please disable it.", objective_function_->GetName());
}
} }
is_constant_hessian_ = GetIsConstHessian(objective_function); is_constant_hessian_ = GetIsConstHessian(objective_function);
...@@ -665,6 +668,9 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* ...@@ -665,6 +668,9 @@ void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction*
objective_function_ = objective_function; objective_function_ = objective_function;
if (objective_function_ != nullptr) { if (objective_function_ != nullptr) {
CHECK_EQ(num_tree_per_iteration_, objective_function_->NumModelPerIteration()); CHECK_EQ(num_tree_per_iteration_, objective_function_->NumModelPerIteration());
if (objective_function_->IsRenewTreeOutput() && !config_->monotone_constraints.empty()) {
Log::Fatal("Cannot use ``monotone_constraints`` in %s objective, please disable it.", objective_function_->GetName());
}
} }
is_constant_hessian_ = GetIsConstHessian(objective_function); is_constant_hessian_ = GetIsConstHessian(objective_function);
...@@ -718,6 +724,9 @@ void GBDT::ResetConfig(const Config* config) { ...@@ -718,6 +724,9 @@ void GBDT::ResetConfig(const Config* config) {
if (!config->feature_contri.empty()) { if (!config->feature_contri.empty()) {
CHECK_EQ(static_cast<size_t>(train_data_->num_total_features()), config->feature_contri.size()); CHECK_EQ(static_cast<size_t>(train_data_->num_total_features()), config->feature_contri.size());
} }
if (objective_function_ != nullptr && objective_function_->IsRenewTreeOutput() && !config->monotone_constraints.empty()) {
Log::Fatal("Cannot use ``monotone_constraints`` in %s objective, please disable it.", objective_function_->GetName());
}
early_stopping_round_ = new_config->early_stopping_round; early_stopping_round_ = new_config->early_stopping_round;
shrinkage_rate_ = new_config->learning_rate; shrinkage_rate_ = new_config->learning_rate;
if (tree_learner_ != nullptr) { if (tree_learner_ != nullptr) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment