Unverified Commit 1dbe5e99 authored by shiyu1994's avatar shiyu1994 Committed by GitHub
Browse files

create buffer for gradients and hessians with goss and customized objective (fixes #3243) (#3263)



* fix bug for GOSS with customized objective (fixes #3243)

* Apply suggestions from code review
Co-authored-by: default avatarGuolin Ke <guolin.ke@outlook.com>
parent 4f28233b
......@@ -36,6 +36,12 @@ class GOSS: public GBDT {
const std::vector<const Metric*>& training_metrics) override {
GBDT::Init(config, train_data, objective_function, training_metrics);
ResetGoss();
if (objective_function_ == nullptr) {
// use customized objective function
size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
gradients_.resize(total_size, 0.0f);
hessians_.resize(total_size, 0.0f);
}
}
void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* objective_function,
......@@ -49,6 +55,23 @@ class GOSS: public GBDT {
ResetGoss();
}
bool TrainOneIter(const score_t* gradients, const score_t* hessians) override {
if (gradients != nullptr) {
// use customized objective function
CHECK(hessians != nullptr && objective_function_ == nullptr);
size_t total_size = static_cast<size_t>(num_data_) * num_tree_per_iteration_;
#pragma omp parallel for schedule(static)
for (size_t i = 0; i < total_size; ++i) {
gradients_[i] = gradients[i];
hessians_[i] = hessians[i];
}
return GBDT::TrainOneIter(gradients_.data(), hessians_.data());
} else {
CHECK(hessians == nullptr);
return GBDT::TrainOneIter(nullptr, nullptr);
}
}
void ResetGoss() {
CHECK_LE(config_->top_rate + config_->other_rate, 1.0f);
CHECK(config_->top_rate > 0.0f && config_->other_rate > 0.0f);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment