Commit 4accb9d4 authored by Guolin Ke's avatar Guolin Ke
Browse files

support merge two booster

parent 14a67b7e
...@@ -35,6 +35,11 @@ public: ...@@ -35,6 +35,11 @@ public:
const ObjectiveFunction* object_function, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) = 0; const std::vector<const Metric*>& training_metrics) = 0;
/*!
* \brief Merge model from other boosting object
* \param other
*/
virtual void MergeFrom(const Boosting* other) = 0;
/*! /*!
* \brief Reset Config for current boosting * \brief Reset Config for current boosting
* \param config Configs for boosting * \param config Configs for boosting
...@@ -179,6 +184,7 @@ public: ...@@ -179,6 +184,7 @@ public:
Boosting(const Boosting&) = delete; Boosting(const Boosting&) = delete;
static void LoadFileToBoosting(Boosting* boosting, const char* filename); static void LoadFileToBoosting(Boosting* boosting, const char* filename);
/*! /*!
* \brief Create boosting object * \brief Create boosting object
* \param type Type of boosting * \param type Type of boosting
......
...@@ -240,8 +240,18 @@ DllExport int LGBM_BoosterCreateFromModelfile( ...@@ -240,8 +240,18 @@ DllExport int LGBM_BoosterCreateFromModelfile(
*/ */
DllExport int LGBM_BoosterFree(BoosterHandle handle); DllExport int LGBM_BoosterFree(BoosterHandle handle);
/*!
* \brief Merge model in two booster to first handle
* \param handle handle, will merge other handle to this
* \param other_handle
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
BoosterHandle other_handle);
/*! /*!
* \brief Add new validation to booster * \brief Add new validation to booster
* \param handle handle
* \param valid_data validation data set * \param valid_data validation data set
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
...@@ -249,7 +259,8 @@ DllExport int LGBM_BoosterAddValidData(BoosterHandle handle, ...@@ -249,7 +259,8 @@ DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
const DatesetHandle valid_data); const DatesetHandle valid_data);
/*! /*!
* \brief Add new validation to booster * \brief Reset training data for booster
* \param handle handle
* \param train_data training data set * \param train_data training data set
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
...@@ -258,6 +269,7 @@ DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle, ...@@ -258,6 +269,7 @@ DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
/*! /*!
* \brief Reset config for current booster * \brief Reset config for current booster
* \param handle handle
* \param parameters format: 'key1=value1 key2=value2' * \param parameters format: 'key1=value1 key2=value2'
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
...@@ -265,6 +277,7 @@ DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* param ...@@ -265,6 +277,7 @@ DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* param
/*! /*!
* \brief Get number of class * \brief Get number of class
* \param handle handle
* \return number of class * \return number of class
*/ */
DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len); DllExport int LGBM_BoosterGetNumClasses(BoosterHandle handle, int64_t* out_len);
......
...@@ -101,10 +101,6 @@ public: ...@@ -101,10 +101,6 @@ public:
/*! \brief Serialize this object by string*/ /*! \brief Serialize this object by string*/
std::string ToString(); std::string ToString();
/*! \brief Disable copy */
Tree& operator=(const Tree&) = delete;
/*! \brief Disable copy */
Tree(const Tree&) = delete;
private: private:
/*! /*!
* \brief Find leaf index of which record belongs by data * \brief Find leaf index of which record belongs by data
......
...@@ -36,6 +36,14 @@ public: ...@@ -36,6 +36,14 @@ public:
const std::vector<const Metric*>& training_metrics) const std::vector<const Metric*>& training_metrics)
override; override;
void MergeFrom(const Boosting* other) override {
auto other_gbdt = reinterpret_cast<const GBDT*>(other);
for (const auto& tree : other_gbdt->models_) {
auto new_tree = std::unique_ptr<Tree>(new Tree(*(tree.get())));
models_.push_back(std::move(new_tree));
}
}
/*! /*!
* \brief Reset Config for current boosting * \brief Reset Config for current boosting
* \param config Configs for boosting * \param config Configs for boosting
......
...@@ -42,6 +42,10 @@ public: ...@@ -42,6 +42,10 @@ public:
Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
} }
void MergeFrom(const Booster* other) {
boosting_->MergeFrom(other->boosting_.get());
}
~Booster() { ~Booster() {
} }
...@@ -465,6 +469,14 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) { ...@@ -465,6 +469,14 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) {
API_END(); API_END();
} }
DllExport int LGBM_BoosterMerge(BoosterHandle handle,
BoosterHandle other_handle) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
Booster* ref_other_booster = reinterpret_cast<Booster*>(other_handle);
ref_booster->MergeFrom(ref_other_booster);
API_END();
}
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle, DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
const DatesetHandle valid_data) { const DatesetHandle valid_data) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment