Commit 629fc047 authored by Guolin Ke

more flexible Python basic object

parent b41e0f0a
@@ -37,6 +37,7 @@ public:
   /*!
   * \brief Merge model from other boosting object
+  Will insert to the front of current boosting object
   * \param other
   */
   virtual void MergeFrom(const Boosting* other) = 0;
......
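The new doc line on MergeFrom pins down the ordering: the trees of `other` end up in front of the trees already held by the current boosting object. A minimal illustration of that contract (my own sketch, using plain integers as stand-ins for the real `Tree` objects):

```cpp
#include <cassert>
#include <vector>

// Toy model of the MergeFrom ordering: "other" goes first,
// the current object's trees are appended after it.
std::vector<int> MergeToFront(const std::vector<int>& other,
                              const std::vector<int>& current) {
  std::vector<int> merged(other);  // merged-in model leads
  merged.insert(merged.end(), current.begin(), current.end());
  return merged;
}

int main() {
  assert(MergeToFront({10, 11}, {20, 21}) == (std::vector<int>{10, 11, 20, 21}));
  return 0;
}
```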
This diff is collapsed.
@@ -46,12 +46,12 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
   gbdt_config_ = config;
   early_stopping_round_ = gbdt_config_->early_stopping_round;
   shrinkage_rate_ = gbdt_config_->learning_rate;
-  train_data_ = train_data;
+  random_ = Random(gbdt_config_->bagging_seed);
   // create tree learner
   tree_learner_.clear();
   for (int i = 0; i < num_class_; ++i) {
     auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config));
-    new_tree_learner->Init(train_data_);
+    new_tree_learner->Init(train_data);
     // init tree learner
     tree_learner_.push_back(std::move(new_tree_learner));
   }
@@ -63,24 +63,33 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
     training_metrics_.push_back(metric);
   }
   training_metrics_.shrink_to_fit();
-  // create score tracker
-  train_score_updater_.reset(new ScoreUpdater(train_data_, num_class_));
-  num_data_ = train_data_->num_data();
-  // create buffer for gradients and hessians
-  if (object_function_ != nullptr) {
-    gradients_ = std::vector<score_t>(num_data_ * num_class_);
-    hessians_ = std::vector<score_t>(num_data_ * num_class_);
-  }
   sigmoid_ = -1.0f;
   if (object_function_ != nullptr
       && std::string(object_function_->GetName()) == std::string("binary")) {
     // only binary classification need sigmoid transform
     sigmoid_ = gbdt_config_->sigmoid;
   }
+  if (train_data_ != train_data) {
+    // not same training data, need reset score and others
+    // create score tracker
+    train_score_updater_.reset(new ScoreUpdater(train_data, num_class_));
+    // update score
+    for (int i = 0; i < iter_; ++i) {
+      for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
+        auto curr_tree = (i + num_init_iteration_) * num_class_ + curr_class;
+        train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
+      }
+    }
+    num_data_ = train_data->num_data();
+    // create buffer for gradients and hessians
+    if (object_function_ != nullptr) {
+      gradients_ = std::vector<score_t>(num_data_ * num_class_);
+      hessians_ = std::vector<score_t>(num_data_ * num_class_);
+    }
   // get max feature index
-  max_feature_idx_ = train_data_->num_total_features() - 1;
+  max_feature_idx_ = train_data->num_total_features() - 1;
   // get label index
-  label_idx_ = train_data_->label_idx();
+  label_idx_ = train_data->label_idx();
   // if need bagging, create buffer
   if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
     out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
@@ -91,14 +100,8 @@ void GBDT::ResetTrainingData(const BoostingConfig* config, const Dataset* train_
     bag_data_cnt_ = num_data_;
     bag_data_indices_.clear();
   }
-  random_ = Random(gbdt_config_->bagging_seed);
-  // update score
-  for (int i = 0; i < iter_; ++i) {
-    for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
-      auto curr_tree = i * num_class_ + curr_class;
-      train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
-    }
-  }
+  }
+  train_data_ = train_data;
 }

 void GBDT::AddValidDataset(const Dataset* valid_data,
@@ -111,7 +114,7 @@ void GBDT::AddValidDataset(const Dataset* valid_data,
   // update score
   for (int i = 0; i < iter_; ++i) {
     for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
-      auto curr_tree = i * num_class_ + curr_class;
+      auto curr_tree = (i + num_init_iteration_) * num_class_ + curr_class;
       new_score_updater->AddScore(models_[curr_tree].get(), curr_class);
     }
   }
@@ -232,7 +235,7 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
 void GBDT::RollbackOneIter() {
   if (iter_ == 0) { return; }
-  int cur_iter = iter_ - 1;
+  int cur_iter = iter_ + num_init_iteration_ - 1;
   // reset score
   for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
     auto curr_tree = cur_iter * num_class_ + curr_class;
......
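The score-update loops and RollbackOneIter above now offset the tree index by num_init_iteration_, because trees trained in the current session sit behind the merged-in initial model inside models_. A small worked sketch of that indexing (my illustration, assuming the layout implied by the diff: num_class_ trees per iteration, classes grouped within an iteration):

```cpp
#include <cstdio>

// Position of a tree in models_ when num_init_iteration initial
// iterations were merged in front of the current model.
int TreeIndex(int iter_in_session, int curr_class,
              int num_init_iteration, int num_class) {
  return (iter_in_session + num_init_iteration) * num_class + curr_class;
}

int main() {
  // Assumed example: 3 merged-in iterations, 2 classes. The first tree
  // trained in this session for class 1 lives at models_[7].
  std::printf("%d\n", TreeIndex(0, 1, 3, 2));  // prints 7
  return 0;
}
```

Without the offset, the score updaters and RollbackOneIter would touch trees of the merged-in model instead of the ones trained in the current session.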
@@ -36,12 +36,28 @@ public:
                  const std::vector<const Metric*>& training_metrics)
                  override;
+  /*!
+  * \brief Merge model from other boosting object
+  Will insert to the front of current boosting object
+  * \param other
+  */
   void MergeFrom(const Boosting* other) override {
     auto other_gbdt = reinterpret_cast<const GBDT*>(other);
+    // tmp move to other vector
+    auto original_models = std::move(models_);
+    models_ = std::vector<std::unique_ptr<Tree>>();
+    // push model from other first
     for (const auto& tree : other_gbdt->models_) {
       auto new_tree = std::unique_ptr<Tree>(new Tree(*(tree.get())));
       models_.push_back(std::move(new_tree));
     }
+    num_init_iteration_ = static_cast<int>(models_.size()) / num_class_;
+    // push model in current object
+    for (const auto& tree : original_models) {
+      auto new_tree = std::unique_ptr<Tree>(new Tree(*(tree.get())));
+      models_.push_back(std::move(new_tree));
+    }
+    num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
   }

   /*!
@@ -266,6 +282,7 @@ protected:
   int num_iteration_for_pred_;
   /*! \brief Shrinkage rate for one iteration */
   double shrinkage_rate_;
+  /*! \brief Number of loaded initial models */
   int num_init_iteration_;
 };
......
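Condensed, the new MergeFrom body is the bookkeeping below; this is my own simplified mirror (a plain int stands in for Tree, and it keeps the diff's behavior of deep-copying both the other and the original trees), not code from the repository:

```cpp
#include <memory>
#include <vector>

struct ToyGBDT {
  std::vector<std::unique_ptr<int>> models_;
  int num_class_ = 1;
  int num_init_iteration_ = 0;
  int num_iteration_for_pred_ = 0;

  void MergeFrom(const ToyGBDT& other) {
    auto original_models = std::move(models_);  // set current trees aside
    models_.clear();
    for (const auto& t : other.models_) {       // merged-in model goes first
      models_.push_back(std::unique_ptr<int>(new int(*t)));
    }
    num_init_iteration_ = static_cast<int>(models_.size()) / num_class_;
    for (const auto& t : original_models) {     // current trees follow
      models_.push_back(std::unique_ptr<int>(new int(*t)));
    }
    num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
  }
};
```

num_init_iteration_ records how many leading iterations came from the merged-in model, while num_iteration_for_pred_ covers the full merged length; both are measured in iterations, i.e. models_.size() / num_class_.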
@@ -36,7 +36,7 @@ public:
     Log::Warning("continued train from model is not support for c_api, \
                   please use continued train with input score");
   }
-  boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, ""));
+  boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, nullptr));
   ConstructObjectAndTrainingMetrics(train_data);
   // initialize the boosting
   boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
@@ -114,6 +114,10 @@ public:
     return boosting_->TrainOneIter(gradients, hessians, false);
   }
+  void RollbackOneIter() {
+    boosting_->RollbackOneIter();
+  }
   void PrepareForPrediction(int num_iteration, int predict_type) {
     boosting_->SetNumIterationForPred(num_iteration);
     bool is_predict_leaf = false;
@@ -156,24 +160,13 @@ public:
     int idx = 0;
     for (const auto& metric : train_metric_) {
       for (const auto& name : metric->GetName()) {
-        int j = 0;
-        auto name_cstr = name.c_str();
-        while (name_cstr[j] != '\0') {
-          out_strs[idx][j] = name_cstr[j];
-          ++j;
-        }
-        out_strs[idx][j] = '\0';
+        std::strcpy(out_strs[idx], name.c_str());
         ++idx;
       }
     }
     return idx;
   }
-  void RollbackOneIter() {
-    boosting_->RollbackOneIter();
-  }
   const Boosting* GetBoosting() const { return boosting_.get(); }

 private:
......
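In the C API wrapper, the hand-rolled character copy of each metric name is replaced by std::strcpy, which copies the terminating '\0' as well, and the RollbackOneIter wrapper is moved up next to TrainOneIter rather than removed. One caveat that holds for both the old and the new copy (my reading, not stated in the diff): each out_strs[idx] buffer must already be allocated by the caller and be large enough for the name. A minimal sketch of the pattern with hypothetical metric names and caller-provided buffers:

```cpp
#include <cstring>
#include <string>
#include <vector>

int main() {
  // Hypothetical names; the real ones come from metric->GetName().
  std::vector<std::string> names = {"l2", "auc"};

  // Caller-provided output buffers, as the C API expects.
  char buf0[64], buf1[64];
  char* out_strs[] = {buf0, buf1};

  int idx = 0;
  for (const auto& name : names) {
    std::strcpy(out_strs[idx], name.c_str());  // also copies the trailing '\0'
    ++idx;
  }
  return 0;  // idx now holds the number of names written
}
```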