Commit 14a67b7e authored by Guolin Ke's avatar Guolin Ke
Browse files

support dynamic change training data and add validation data

parent 13329682
...@@ -51,6 +51,18 @@ public: ...@@ -51,6 +51,18 @@ public:
explicit BinMapper(const void* memory); explicit BinMapper(const void* memory);
~BinMapper(); ~BinMapper();
bool CheckAlign(const BinMapper& other) const {
  // Two mappers align only when they define exactly the same binning:
  // identical bin count and identical upper bounds. The exact floating
  // point compare is intentional — bins must match bit-for-bit for a
  // dataset swap to be safe.
  if (num_bin_ != other.num_bin_) {
    return false;
  }
  int idx = 0;
  while (idx < num_bin_) {
    if (bin_upper_bound_[idx] != other.bin_upper_bound_[idx]) {
      return false;
    }
    ++idx;
  }
  return true;
}
/*! \brief Get number of bins */ /*! \brief Get number of bins */
inline int num_bin() const { return num_bin_; } inline int num_bin() const { return num_bin_; }
/*! \brief True if bin is trival (contains only one bin) */ /*! \brief True if bin is trival (contains only one bin) */
......
...@@ -41,12 +41,20 @@ public: ...@@ -41,12 +41,20 @@ public:
*/ */
virtual void ResetConfig(const BoostingConfig* config) = 0; virtual void ResetConfig(const BoostingConfig* config) = 0;
/*!
* \brief Reset training data for current boosting
* \param train_data Training data
* \param object_function Training objective function
* \param training_metrics Training metric
*/
virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;
/*! /*!
* \brief Add a validation data * \brief Add a validation data
* \param valid_data Validation data * \param valid_data Validation data
* \param valid_metrics Metric for validation data * \param valid_metrics Metric for validation data
*/ */
virtual void AddDataset(const Dataset* valid_data, virtual void AddValidDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) = 0; const std::vector<const Metric*>& valid_metrics) = 0;
/*! /*!
......
...@@ -212,19 +212,12 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle, ...@@ -212,19 +212,12 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
/*! /*!
* \brief create an new boosting learner * \brief create an new boosting learner
* \param train_data training data set * \param train_data training data set
* \param valid_datas validation data sets
* \param valid_names names of validation data sets
* \param n_valid_datas number of validation set
* \param parameters format: 'key1=value1 key2=value2' * \param parameters format: 'key1=value1 key2=value2'
* \param init_model_filename filename of model
* \param out handle of created Booster * \param out handle of created Booster
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
*/ */
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
const DatesetHandle valid_datas[],
int n_valid_datas,
const char* parameters, const char* parameters,
const char* init_model_filename,
BoosterHandle* out); BoosterHandle* out);
/*! /*!
...@@ -247,6 +240,22 @@ DllExport int LGBM_BoosterCreateFromModelfile( ...@@ -247,6 +240,22 @@ DllExport int LGBM_BoosterCreateFromModelfile(
*/ */
DllExport int LGBM_BoosterFree(BoosterHandle handle); DllExport int LGBM_BoosterFree(BoosterHandle handle);
/*!
* \brief Add new validation data to booster
* \param handle handle of booster
* \param valid_data validation data set
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
const DatesetHandle valid_data);
/*!
* \brief Reset the training data of an existing booster; the new data
*        must use the same bin mappers as the original training data
*        (checked via Dataset::CheckAlign)
* \param handle handle of booster
* \param train_data new training data set
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
const DatesetHandle train_data);
/*! /*!
* \brief Reset config for current booster * \brief Reset config for current booster
* \param parameters format: 'key1=value1 key2=value2' * \param parameters format: 'key1=value1 key2=value2'
......
...@@ -277,6 +277,27 @@ public: ...@@ -277,6 +277,27 @@ public:
/*! \brief Destructor */ /*! \brief Destructor */
~Dataset(); ~Dataset();
bool CheckAlign(const Dataset& other) const {
  // Datasets align when their overall shape matches and every used
  // feature has an identical bin mapper; only then can one dataset be
  // substituted for another in training/validation.
  const bool same_shape =
    num_features_ == other.num_features_ &&
    num_total_features_ == other.num_total_features_ &&
    num_class_ == other.num_class_ &&
    label_idx_ == other.label_idx_;
  if (!same_shape) {
    return false;
  }
  for (int fidx = 0; fidx < num_features_; ++fidx) {
    if (!features_[fidx]->CheckAlign(*other.features_[fidx])) {
      return false;
    }
  }
  return true;
}
inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<double>& feature_values) { inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<double>& feature_values) {
for (size_t i = 0; i < feature_values.size() && i < static_cast<size_t>(num_total_features_); ++i) { for (size_t i = 0; i < feature_values.size() && i < static_cast<size_t>(num_total_features_); ++i) {
int feature_idx = used_feature_map_[i]; int feature_idx = used_feature_map_[i];
......
...@@ -63,6 +63,13 @@ public: ...@@ -63,6 +63,13 @@ public:
~Feature() { ~Feature() {
} }
bool CheckAlign(const Feature& other) const {
  // Features align when they refer to the same raw column and share an
  // identical bin mapper.
  return feature_index_ == other.feature_index_ &&
    bin_mapper_->CheckAlign(*other.bin_mapper_);
}
/*! /*!
* \brief Push one record, will auto convert to bin and push to bin data * \brief Push one record, will auto convert to bin and push to bin data
* \param tid Thread id * \param tid Thread id
......
...@@ -207,7 +207,7 @@ void Application::InitTrain() { ...@@ -207,7 +207,7 @@ void Application::InitTrain() {
Common::ConstPtrInVectorWrapper<Metric>(train_metric_)); Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
// add validation data into boosting // add validation data into boosting
for (size_t i = 0; i < valid_datas_.size(); ++i) { for (size_t i = 0; i < valid_datas_.size(); ++i) {
boosting_->AddDataset(valid_datas_[i].get(), boosting_->AddValidDataset(valid_datas_[i].get(),
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i])); Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
} }
Log::Info("Finished initializing training"); Log::Info("Finished initializing training");
......
...@@ -16,7 +16,10 @@ ...@@ -16,7 +16,10 @@
namespace LightGBM { namespace LightGBM {
GBDT::GBDT() : saved_model_size_(-1), num_iteration_for_pred_(0) { GBDT::GBDT()
:saved_model_size_(-1),
num_iteration_for_pred_(0),
num_init_iteration_(0) {
} }
...@@ -33,8 +36,46 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O ...@@ -33,8 +36,46 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
max_feature_idx_ = 0; max_feature_idx_ = 0;
early_stopping_round_ = gbdt_config_->early_stopping_round; early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate; shrinkage_rate_ = gbdt_config_->learning_rate;
train_data_ = train_data;
num_class_ = config->num_class; num_class_ = config->num_class;
train_data_ = nullptr;
ResetTrainingData(train_data, object_function, training_metrics);
// initialize random generator
random_ = Random(gbdt_config_->bagging_seed);
}
/*!
* \brief Apply a new boosting config to this GBDT.
*        Rebuilds the per-class tree learners and the bagging buffers
*        against the training data already held in train_data_.
* \param config New config (not owned; must outlive this object)
*/
void GBDT::ResetConfig(const BoostingConfig* config) {
gbdt_config_ = config;
early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate;
// recreate tree learners, one per class, bound to the current training data
tree_learner_.clear();
for (int i = 0; i < num_class_; ++i) {
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config));
new_tree_learner->Init(train_data_);
// init tree learner
tree_learner_.push_back(std::move(new_tree_learner));
}
tree_learner_.shrink_to_fit();
// if bagging is enabled, size the index buffers; the counts are
// presumably filled in when bagging is performed each iteration —
// NOTE(review): confirm against the bagging routine, not visible here
if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
bag_data_indices_ = std::vector<data_size_t>(num_data_);
} else {
// no bagging: every record is in-bag, buffers stay empty
out_of_bag_data_cnt_ = 0;
out_of_bag_data_indices_.clear();
bag_data_cnt_ = num_data_;
bag_data_indices_.clear();
}
// re-seed the random generator from the (possibly new) bagging seed
random_ = Random(gbdt_config_->bagging_seed);
}
void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) {
if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
Log::Fatal("cannot reset training data, since new training data has different bin mappers");
}
train_data_ = train_data;
// create tree learner // create tree learner
tree_learner_.clear(); tree_learner_.clear();
for (int i = 0; i < num_class_; ++i) { for (int i = 0; i < num_class_; ++i) {
...@@ -46,6 +87,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O ...@@ -46,6 +87,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
tree_learner_.shrink_to_fit(); tree_learner_.shrink_to_fit();
object_function_ = object_function; object_function_ = object_function;
// push training metrics // push training metrics
training_metrics_.clear();
for (const auto& metric : training_metrics) { for (const auto& metric : training_metrics) {
training_metrics_.push_back(metric); training_metrics_.push_back(metric);
} }
...@@ -78,44 +120,29 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O ...@@ -78,44 +120,29 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
bag_data_cnt_ = num_data_; bag_data_cnt_ = num_data_;
bag_data_indices_.clear(); bag_data_indices_.clear();
} }
// initialize random generator // update score
random_ = Random(gbdt_config_->bagging_seed); for (int i = 0; i < iter_; ++i) {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
} auto curr_tree = i * num_class_ + curr_class;
train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
void GBDT::ResetConfig(const BoostingConfig* config) {
gbdt_config_ = config;
early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate;
// create tree learner
tree_learner_.clear();
for (int i = 0; i < num_class_; ++i) {
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config));
new_tree_learner->Init(train_data_);
// init tree learner
tree_learner_.push_back(std::move(new_tree_learner));
} }
tree_learner_.shrink_to_fit();
// if need bagging, create buffer
if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
bag_data_indices_ = std::vector<data_size_t>(num_data_);
} else {
out_of_bag_data_cnt_ = 0;
out_of_bag_data_indices_.clear();
bag_data_cnt_ = num_data_;
bag_data_indices_.clear();
} }
// initialize random generator
random_ = Random(gbdt_config_->bagging_seed);
} }
void GBDT::AddDataset(const Dataset* valid_data,
void GBDT::AddValidDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) { const std::vector<const Metric*>& valid_metrics) {
if (iter_ > 0) { if (!train_data_->CheckAlign(*valid_data)) {
Log::Fatal("Cannot add validation data after training started"); Log::Fatal("cannot add validation data, since it has different bin mappers with training data");
} }
// for a validation dataset, we need its score and metric // for a validation dataset, we need its score and metric
auto new_score_updater = std::unique_ptr<ScoreUpdater>(new ScoreUpdater(valid_data, num_class_)); auto new_score_updater = std::unique_ptr<ScoreUpdater>(new ScoreUpdater(valid_data, num_class_));
// update score
for (int i = 0; i < iter_; ++i) {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
auto curr_tree = i * num_class_ + curr_class;
new_score_updater->AddScore(models_[curr_tree].get(), curr_class);
}
}
valid_score_updater_.push_back(std::move(new_score_updater)); valid_score_updater_.push_back(std::move(new_score_updater));
valid_metrics_.emplace_back(); valid_metrics_.emplace_back();
if (early_stopping_round_ > 0) { if (early_stopping_round_ > 0) {
...@@ -499,6 +526,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) { ...@@ -499,6 +526,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
} }
Log::Info("Finished loading %d models", models_.size()); Log::Info("Finished loading %d models", models_.size());
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_; num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
num_init_iteration_ = num_iteration_for_pred_;
} }
std::string GBDT::FeatureImportance() const { std::string GBDT::FeatureImportance() const {
......
...@@ -42,12 +42,20 @@ public: ...@@ -42,12 +42,20 @@ public:
*/ */
void ResetConfig(const BoostingConfig* config) override; void ResetConfig(const BoostingConfig* config) override;
/*!
* \brief Reset training data for current boosting
* \param train_data Training data
* \param object_function Training objective function
* \param training_metrics Training metric
*/
void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) override;
/*! /*!
* \brief Adding a validation dataset * \brief Adding a validation dataset
* \param valid_data Validation dataset * \param valid_data Validation dataset
* \param valid_metrics Metrics for validation dataset * \param valid_metrics Metrics for validation dataset
*/ */
void AddDataset(const Dataset* valid_data, void AddValidDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) override; const std::vector<const Metric*>& valid_metrics) override;
/*! /*!
* \brief Training logic * \brief Training logic
...@@ -63,7 +71,7 @@ public: ...@@ -63,7 +71,7 @@ public:
*/ */
void RollbackOneIter() override; void RollbackOneIter() override;
int GetCurrentIteration() const override { return iter_; } int GetCurrentIteration() const override { return iter_ + num_init_iteration_; }
bool EvalAndCheckEarlyStopping() override; bool EvalAndCheckEarlyStopping() override;
...@@ -256,6 +264,7 @@ protected: ...@@ -256,6 +264,7 @@ protected:
int num_iteration_for_pred_; int num_iteration_for_pred_;
/*! \brief Shrinkage rate for one iteration */ /*! \brief Shrinkage rate for one iteration */
double shrinkage_rate_; double shrinkage_rate_;
int num_init_iteration_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -28,9 +28,7 @@ public: ...@@ -28,9 +28,7 @@ public:
} }
Booster(const Dataset* train_data, Booster(const Dataset* train_data,
std::vector<const Dataset*> valid_data, const char* parameters) {
const char* parameters)
:train_data_(train_data), valid_datas_(valid_data) {
config_.LoadFromString(parameters); config_.LoadFromString(parameters);
// create boosting // create boosting
if (config_.io_config.input_model.size() > 0) { if (config_.io_config.input_model.size() > 0) {
...@@ -38,6 +36,17 @@ public: ...@@ -38,6 +36,17 @@ public:
please use continued train with input score"); please use continued train with input score");
} }
boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, "")); boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, ""));
ConstructObjectAndTrainingMetrics(train_data);
// initialize the boosting
boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
}
~Booster() {
}
void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
// create objective function // create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type, objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config)); config_.objective_config));
...@@ -45,48 +54,39 @@ public: ...@@ -45,48 +54,39 @@ public:
Log::Warning("Using self-defined objective functions"); Log::Warning("Using self-defined objective functions");
} }
// create training metric // create training metric
train_metric_.clear();
for (auto metric_type : config_.metric_types) { for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>( auto metric = std::unique_ptr<Metric>(
Metric::CreateMetric(metric_type, config_.metric_config)); Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; } if (metric == nullptr) { continue; }
metric->Init(train_data_->metadata(), train_data_->num_data()); metric->Init(train_data->metadata(), train_data->num_data());
train_metric_.push_back(std::move(metric)); train_metric_.push_back(std::move(metric));
} }
train_metric_.shrink_to_fit(); train_metric_.shrink_to_fit();
// add metric for validation data
for (size_t i = 0; i < valid_datas_.size(); ++i) {
valid_metrics_.emplace_back();
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(valid_datas_[i]->metadata(), valid_datas_[i]->num_data());
valid_metrics_.back().push_back(std::move(metric));
}
valid_metrics_.back().shrink_to_fit();
}
valid_metrics_.shrink_to_fit();
// initialize the objective function // initialize the objective function
if (objective_fun_ != nullptr) { if (objective_fun_ != nullptr) {
objective_fun_->Init(train_data_->metadata(), train_data_->num_data()); objective_fun_->Init(train_data->metadata(), train_data->num_data());
}
// initialize the boosting
boosting_->Init(&config_.boosting_config, train_data_, objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
// add validation data into boosting
for (size_t i = 0; i < valid_datas_.size(); ++i) {
boosting_->AddDataset(valid_datas_[i],
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
} }
} }
void LoadModelFromFile(const char* filename) { void ResetTrainingData(const Dataset* train_data) {
Boosting::LoadFileToBoosting(boosting_.get(), filename); ConstructObjectAndTrainingMetrics(train_data);
// initialize the boosting
boosting_->ResetTrainingData(train_data, objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
} }
~Booster() { void AddValidData(const Dataset* valid_data) {
valid_metrics_.emplace_back();
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(valid_data->metadata(), valid_data->num_data());
valid_metrics_.back().push_back(std::move(metric));
}
valid_metrics_.back().shrink_to_fit();
boosting_->AddValidDataset(valid_data,
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
} }
bool TrainOneIter() { bool TrainOneIter() {
return boosting_->TrainOneIter(nullptr, nullptr, false); return boosting_->TrainOneIter(nullptr, nullptr, false);
} }
...@@ -151,9 +151,7 @@ public: ...@@ -151,9 +151,7 @@ public:
} }
void ResetBoostingConfig(const char* parameters) { void ResetBoostingConfig(const char* parameters) {
OverallConfig new_config; config_.LoadFromString(parameters);
new_config.LoadFromString(parameters);
config_.boosting_config = new_config.boosting_config;
boosting_->ResetConfig(&config_.boosting_config); boosting_->ResetConfig(&config_.boosting_config);
} }
...@@ -164,14 +162,9 @@ public: ...@@ -164,14 +162,9 @@ public:
const Boosting* GetBoosting() const { return boosting_.get(); } const Boosting* GetBoosting() const { return boosting_.get(); }
private: private:
std::unique_ptr<Boosting> boosting_; std::unique_ptr<Boosting> boosting_;
/*! \brief All configs */ /*! \brief All configs */
OverallConfig config_; OverallConfig config_;
/*! \brief Training data */
const Dataset* train_data_;
/*! \brief Validation data */
std::vector<const Dataset*> valid_datas_;
/*! \brief Metric for training data */ /*! \brief Metric for training data */
std::vector<std::unique_ptr<Metric>> train_metric_; std::vector<std::unique_ptr<Metric>> train_metric_;
/*! \brief Metrics for validation data */ /*! \brief Metrics for validation data */
...@@ -446,21 +439,11 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle, ...@@ -446,21 +439,11 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
// ---- start of booster // ---- start of booster
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data, DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
const DatesetHandle valid_datas[],
int n_valid_datas,
const char* parameters, const char* parameters,
const char* init_model_filename,
BoosterHandle* out) { BoosterHandle* out) {
API_BEGIN(); API_BEGIN();
const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data); const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
std::vector<const Dataset*> p_valid_datas; auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
for (int i = 0; i < n_valid_datas; ++i) {
p_valid_datas.emplace_back(reinterpret_cast<const Dataset*>(valid_datas[i]));
}
auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, p_valid_datas, parameters));
if (init_model_filename != nullptr) {
ret->LoadModelFromFile(init_model_filename);
}
*out = ret.release(); *out = ret.release();
API_END(); API_END();
} }
...@@ -482,6 +465,25 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) { ...@@ -482,6 +465,25 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) {
API_END(); API_END();
} }
// C-API entry: attach one more validation dataset to an existing booster.
// Returns 0 on success, -1 on failure (via the API_BEGIN/API_END guard).
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
                                       const DatesetHandle valid_data) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->AddValidData(reinterpret_cast<const Dataset*>(valid_data));
  API_END();
}
// C-API entry: swap the booster's training dataset for a new one.
// Returns 0 on success, -1 on failure (via the API_BEGIN/API_END guard).
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
                                            const DatesetHandle train_data) {
  API_BEGIN();
  auto* booster = reinterpret_cast<Booster*>(handle);
  booster->ResetTrainingData(reinterpret_cast<const Dataset*>(train_data));
  API_END();
}
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) { DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
API_BEGIN(); API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle); Booster* ref_booster = reinterpret_cast<Booster*>(handle);
......
...@@ -174,10 +174,10 @@ def test_dataset(): ...@@ -174,10 +174,10 @@ def test_dataset():
test_free_dataset(train) test_free_dataset(train)
def test_booster(): def test_booster():
train = test_load_from_mat('../../examples/binary_classification/binary.train', None) train = test_load_from_mat('../../examples/binary_classification/binary.train', None)
test = [test_load_from_mat('../../examples/binary_classification/binary.test', train)] test = test_load_from_mat('../../examples/binary_classification/binary.test', train)
booster = ctypes.c_void_p() booster = ctypes.c_void_p()
LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test), LIB.LGBM_BoosterCreate(train, c_str("app=binary metric=auc num_leaves=31 verbose=0"), ctypes.byref(booster))
len(test), c_str("app=binary metric=auc num_leaves=31 verbose=0"),None, ctypes.byref(booster)) LIB.LGBM_BoosterAddValidData(booster, test)
is_finished = ctypes.c_int(0) is_finished = ctypes.c_int(0)
for i in range(100): for i in range(100):
LIB.LGBM_BoosterUpdateOneIter(booster,ctypes.byref(is_finished)) LIB.LGBM_BoosterUpdateOneIter(booster,ctypes.byref(is_finished))
...@@ -188,7 +188,7 @@ def test_booster(): ...@@ -188,7 +188,7 @@ def test_booster():
LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt')) LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt'))
LIB.LGBM_BoosterFree(booster) LIB.LGBM_BoosterFree(booster)
test_free_dataset(train) test_free_dataset(train)
test_free_dataset(test[0]) test_free_dataset(test)
booster2 = ctypes.c_void_p() booster2 = ctypes.c_void_p()
num_total_model = ctypes.c_long() num_total_model = ctypes.c_long()
LIB.LGBM_BoosterCreateFromModelfile(c_str('model.txt'), ctypes.byref(num_total_model), ctypes.byref(booster2)) LIB.LGBM_BoosterCreateFromModelfile(c_str('model.txt'), ctypes.byref(num_total_model), ctypes.byref(booster2))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment