Commit 14a67b7e authored by Guolin Ke's avatar Guolin Ke
Browse files

support dynamic change training data and add validation data

parent 13329682
......@@ -51,6 +51,18 @@ public:
explicit BinMapper(const void* memory);
~BinMapper();
/*!
* \brief Check whether another mapper uses the identical binning.
* \param other Mapper to compare against
* \return true iff bin counts match and every bin upper bound is equal
*/
bool CheckAlign(const BinMapper& other) const {
  bool aligned = (num_bin_ == other.num_bin_);
  // stop scanning as soon as one bound differs
  for (int idx = 0; aligned && idx < num_bin_; ++idx) {
    aligned = (bin_upper_bound_[idx] == other.bin_upper_bound_[idx]);
  }
  return aligned;
}
/*! \brief Get number of bins */
inline int num_bin() const { return num_bin_; }
/*! \brief True if bin is trivial (contains only one bin) */
......
......@@ -41,12 +41,20 @@ public:
*/
virtual void ResetConfig(const BoostingConfig* config) = 0;
/*!
* \brief Reset training data for current boosting
* \param train_data Training data
* \param object_function Training objective function
* \param training_metrics Training metric
*/
virtual void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) = 0;
/*!
* \brief Add a validation data
* \param valid_data Validation data
* \param valid_metrics Metric for validation data
*/
virtual void AddDataset(const Dataset* valid_data,
virtual void AddValidDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) = 0;
/*!
......
......@@ -212,19 +212,12 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
/*!
* \brief create a new boosting learner
* \param train_data training data set
* \param valid_datas validation data sets
* \param valid_names names of validation data sets
* \param n_valid_datas number of validation set
* \param parameters format: 'key1=value1 key2=value2'
* \param init_model_filename filename of model
* \param out handle of created Booster
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
const DatesetHandle valid_datas[],
int n_valid_datas,
const char* parameters,
const char* init_model_filename,
BoosterHandle* out);
/*!
......@@ -247,6 +240,22 @@ DllExport int LGBM_BoosterCreateFromModelfile(
*/
DllExport int LGBM_BoosterFree(BoosterHandle handle);
/*!
* \brief Add new validation to booster
* \param valid_data validation data set
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
const DatesetHandle valid_data);
/*!
* \brief Reset the training data for an existing booster
* \param train_data training data set
* \return 0 when success, -1 when failure happens
*/
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
const DatesetHandle train_data);
/*!
* \brief Reset config for current booster
* \param parameters format: 'key1=value1 key2=value2'
......
......@@ -277,6 +277,27 @@ public:
/*! \brief Destructor */
~Dataset();
/*!
* \brief Check whether another dataset shares this dataset's feature layout.
* \param other Dataset to compare against
* \return true iff feature counts, class count, label index and all
*         per-feature bin mappers match
*/
bool CheckAlign(const Dataset& other) const {
  // compare the cheap scalar metadata first
  const bool meta_same =
      num_features_ == other.num_features_ &&
      num_total_features_ == other.num_total_features_ &&
      num_class_ == other.num_class_ &&
      label_idx_ == other.label_idx_;
  if (!meta_same) {
    return false;
  }
  // then require every feature's bin mapper to align
  for (int fidx = 0; fidx < num_features_; ++fidx) {
    if (!features_[fidx]->CheckAlign(*other.features_[fidx])) {
      return false;
    }
  }
  return true;
}
inline void PushOneRow(int tid, data_size_t row_idx, const std::vector<double>& feature_values) {
for (size_t i = 0; i < feature_values.size() && i < static_cast<size_t>(num_total_features_); ++i) {
int feature_idx = used_feature_map_[i];
......
......@@ -63,6 +63,13 @@ public:
~Feature() {
}
/*!
* \brief Check whether another feature has the same index and binning.
* \param other Feature to compare against
* \return true iff feature indices match and the bin mappers align
*/
bool CheckAlign(const Feature& other) const {
  return feature_index_ == other.feature_index_
      && bin_mapper_->CheckAlign(*other.bin_mapper_);
}
/*!
* \brief Push one record, will auto convert to bin and push to bin data
* \param tid Thread id
......
......@@ -207,7 +207,7 @@ void Application::InitTrain() {
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
// add validation data into boosting
for (size_t i = 0; i < valid_datas_.size(); ++i) {
boosting_->AddDataset(valid_datas_[i].get(),
boosting_->AddValidDataset(valid_datas_[i].get(),
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
}
Log::Info("Finished initializing training");
......
......@@ -16,7 +16,10 @@
namespace LightGBM {
GBDT::GBDT() : saved_model_size_(-1), num_iteration_for_pred_(0) {
GBDT::GBDT()
:saved_model_size_(-1),
num_iteration_for_pred_(0),
num_init_iteration_(0) {
}
......@@ -33,8 +36,46 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
max_feature_idx_ = 0;
early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate;
train_data_ = train_data;
num_class_ = config->num_class;
train_data_ = nullptr;
ResetTrainingData(train_data, object_function, training_metrics);
// initialize random generator
random_ = Random(gbdt_config_->bagging_seed);
}
/*!
* \brief Apply a new boosting config to this booster.
*        Rebuilds the per-class tree learners against the current training
*        data, re-allocates (or clears) the bagging buffers, and re-seeds the
*        random generator. Assumes train_data_ is already set.
* \param config New config (not owned; must outlive this object)
*/
void GBDT::ResetConfig(const BoostingConfig* config) {
  gbdt_config_ = config;
  early_stopping_round_ = gbdt_config_->early_stopping_round;
  shrinkage_rate_ = gbdt_config_->learning_rate;
  // create tree learner
  tree_learner_.clear();
  // one tree learner per class (multiclass trains one tree per class per iteration)
  for (int i = 0; i < num_class_; ++i) {
    auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config));
    new_tree_learner->Init(train_data_);
    // init tree learner
    tree_learner_.push_back(std::move(new_tree_learner));
  }
  tree_learner_.shrink_to_fit();
  // if need bagging, create buffer
  if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
    out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
    bag_data_indices_ = std::vector<data_size_t>(num_data_);
  } else {
    // bagging disabled: all data is in-bag, out-of-bag set is empty
    out_of_bag_data_cnt_ = 0;
    out_of_bag_data_indices_.clear();
    bag_data_cnt_ = num_data_;
    bag_data_indices_.clear();
  }
  // initialize random generator
  random_ = Random(gbdt_config_->bagging_seed);
}
void GBDT::ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) {
if (train_data_ != nullptr && !train_data_->CheckAlign(*train_data)) {
Log::Fatal("cannot reset training data, since new training data has different bin mappers");
}
train_data_ = train_data;
// create tree learner
tree_learner_.clear();
for (int i = 0; i < num_class_; ++i) {
......@@ -46,6 +87,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
tree_learner_.shrink_to_fit();
object_function_ = object_function;
// push training metrics
training_metrics_.clear();
for (const auto& metric : training_metrics) {
training_metrics_.push_back(metric);
}
......@@ -59,7 +101,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
hessians_ = std::vector<score_t>(num_data_ * num_class_);
}
sigmoid_ = -1.0f;
if (object_function_ != nullptr
if (object_function_ != nullptr
&& std::string(object_function_->GetName()) == std::string("binary")) {
// only binary classification need sigmoid transform
sigmoid_ = gbdt_config_->sigmoid;
......@@ -78,44 +120,29 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
bag_data_cnt_ = num_data_;
bag_data_indices_.clear();
}
// initialize random generator
random_ = Random(gbdt_config_->bagging_seed);
}
void GBDT::ResetConfig(const BoostingConfig* config) {
gbdt_config_ = config;
early_stopping_round_ = gbdt_config_->early_stopping_round;
shrinkage_rate_ = gbdt_config_->learning_rate;
// create tree learner
tree_learner_.clear();
for (int i = 0; i < num_class_; ++i) {
auto new_tree_learner = std::unique_ptr<TreeLearner>(TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config));
new_tree_learner->Init(train_data_);
// init tree learner
tree_learner_.push_back(std::move(new_tree_learner));
}
tree_learner_.shrink_to_fit();
// if need bagging, create buffer
if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
out_of_bag_data_indices_ = std::vector<data_size_t>(num_data_);
bag_data_indices_ = std::vector<data_size_t>(num_data_);
} else {
out_of_bag_data_cnt_ = 0;
out_of_bag_data_indices_.clear();
bag_data_cnt_ = num_data_;
bag_data_indices_.clear();
// update score
for (int i = 0; i < iter_; ++i) {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
auto curr_tree = i * num_class_ + curr_class;
train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
}
}
// initialize random generator
random_ = Random(gbdt_config_->bagging_seed);
}
void GBDT::AddDataset(const Dataset* valid_data,
void GBDT::AddValidDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) {
if (iter_ > 0) {
Log::Fatal("Cannot add validation data after training started");
if (!train_data_->CheckAlign(*valid_data)) {
Log::Fatal("cannot add validation data, since it has different bin mappers with training data");
}
// for a validation dataset, we need its score and metric
auto new_score_updater = std::unique_ptr<ScoreUpdater>(new ScoreUpdater(valid_data, num_class_));
// update score
for (int i = 0; i < iter_; ++i) {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
auto curr_tree = i * num_class_ + curr_class;
new_score_updater->AddScore(models_[curr_tree].get(), curr_class);
}
}
valid_score_updater_.push_back(std::move(new_score_updater));
valid_metrics_.emplace_back();
if (early_stopping_round_ > 0) {
......@@ -499,6 +526,7 @@ void GBDT::LoadModelFromString(const std::string& model_str) {
}
Log::Info("Finished loading %d models", models_.size());
num_iteration_for_pred_ = static_cast<int>(models_.size()) / num_class_;
num_init_iteration_ = num_iteration_for_pred_;
}
std::string GBDT::FeatureImportance() const {
......
......@@ -42,12 +42,20 @@ public:
*/
void ResetConfig(const BoostingConfig* config) override;
/*!
* \brief Reset training data for current boosting
* \param train_data Training data
* \param object_function Training objective function
* \param training_metrics Training metric
*/
void ResetTrainingData(const Dataset* train_data, const ObjectiveFunction* object_function, const std::vector<const Metric*>& training_metrics) override;
/*!
* \brief Adding a validation dataset
* \param valid_data Validation dataset
* \param valid_metrics Metrics for validation dataset
*/
void AddDataset(const Dataset* valid_data,
void AddValidDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) override;
/*!
* \brief Training logic
......@@ -63,7 +71,7 @@ public:
*/
void RollbackOneIter() override;
int GetCurrentIteration() const override { return iter_; }
int GetCurrentIteration() const override { return iter_ + num_init_iteration_; }
bool EvalAndCheckEarlyStopping() override;
......@@ -256,6 +264,7 @@ protected:
int num_iteration_for_pred_;
/*! \brief Shrinkage rate for one iteration */
double shrinkage_rate_;
int num_init_iteration_;
};
} // namespace LightGBM
......
......@@ -28,9 +28,7 @@ public:
}
Booster(const Dataset* train_data,
std::vector<const Dataset*> valid_data,
const char* parameters)
:train_data_(train_data), valid_datas_(valid_data) {
const char* parameters) {
config_.LoadFromString(parameters);
// create boosting
if (config_.io_config.input_model.size() > 0) {
......@@ -38,6 +36,17 @@ public:
please use continued train with input score");
}
boosting_.reset(Boosting::CreateBoosting(config_.boosting_type, ""));
ConstructObjectAndTrainingMetrics(train_data);
// initialize the boosting
boosting_->Init(&config_.boosting_config, train_data, objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
}
~Booster() {
}
void ConstructObjectAndTrainingMetrics(const Dataset* train_data) {
// create objective function
objective_fun_.reset(ObjectiveFunction::CreateObjectiveFunction(config_.objective_type,
config_.objective_config));
......@@ -45,48 +54,39 @@ public:
Log::Warning("Using self-defined objective functions");
}
// create training metric
train_metric_.clear();
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(
Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(train_data_->metadata(), train_data_->num_data());
metric->Init(train_data->metadata(), train_data->num_data());
train_metric_.push_back(std::move(metric));
}
train_metric_.shrink_to_fit();
// add metric for validation data
for (size_t i = 0; i < valid_datas_.size(); ++i) {
valid_metrics_.emplace_back();
for (auto metric_type : config_.metric_types) {
auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
if (metric == nullptr) { continue; }
metric->Init(valid_datas_[i]->metadata(), valid_datas_[i]->num_data());
valid_metrics_.back().push_back(std::move(metric));
}
valid_metrics_.back().shrink_to_fit();
}
valid_metrics_.shrink_to_fit();
// initialize the objective function
if (objective_fun_ != nullptr) {
objective_fun_->Init(train_data_->metadata(), train_data_->num_data());
}
// initialize the boosting
boosting_->Init(&config_.boosting_config, train_data_, objective_fun_.get(),
Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
// add validation data into boosting
for (size_t i = 0; i < valid_datas_.size(); ++i) {
boosting_->AddDataset(valid_datas_[i],
Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_[i]));
objective_fun_->Init(train_data->metadata(), train_data->num_data());
}
}
void LoadModelFromFile(const char* filename) {
Boosting::LoadFileToBoosting(boosting_.get(), filename);
/*!
* \brief Swap in a new training dataset: rebuild the objective function and
*        training metrics for it, then hand everything to the boosting object.
* \param train_data New training data (not owned; must stay alive while training)
*/
void ResetTrainingData(const Dataset* train_data) {
  ConstructObjectAndTrainingMetrics(train_data);
  // initialize the boosting
  boosting_->ResetTrainingData(train_data, objective_fun_.get(), Common::ConstPtrInVectorWrapper<Metric>(train_metric_));
}
~Booster() {
/*!
* \brief Register an additional validation dataset and build its metrics.
* \param valid_data Validation data (not owned; must outlive this booster)
*/
void AddValidData(const Dataset* valid_data) {
  // new metric slot for this validation set
  valid_metrics_.emplace_back();
  // one metric object per configured metric type
  for (auto metric_type : config_.metric_types) {
    auto metric = std::unique_ptr<Metric>(Metric::CreateMetric(metric_type, config_.metric_config));
    if (metric == nullptr) { continue; }  // factory returns null for unsupported types
    metric->Init(valid_data->metadata(), valid_data->num_data());
    valid_metrics_.back().push_back(std::move(metric));
  }
  valid_metrics_.back().shrink_to_fit();
  // boosting keeps raw pointers into valid_metrics_.back(); the vector owns them
  boosting_->AddValidDataset(valid_data,
    Common::ConstPtrInVectorWrapper<Metric>(valid_metrics_.back()));
}
/*!
* \brief Run one boosting iteration; null gradients/hessians tell the
*        boosting object to compute them from its own objective function.
* \return forwarded from Boosting::TrainOneIter — presumably true when
*         training should stop (e.g. early stopping); confirm in the
*         boosting interface
*/
bool TrainOneIter() {
  return boosting_->TrainOneIter(nullptr, nullptr, false);
}
......@@ -151,9 +151,7 @@ public:
}
void ResetBoostingConfig(const char* parameters) {
OverallConfig new_config;
new_config.LoadFromString(parameters);
config_.boosting_config = new_config.boosting_config;
config_.LoadFromString(parameters);
boosting_->ResetConfig(&config_.boosting_config);
}
......@@ -164,14 +162,9 @@ public:
const Boosting* GetBoosting() const { return boosting_.get(); }
private:
std::unique_ptr<Boosting> boosting_;
/*! \brief All configs */
OverallConfig config_;
/*! \brief Training data */
const Dataset* train_data_;
/*! \brief Validation data */
std::vector<const Dataset*> valid_datas_;
/*! \brief Metric for training data */
std::vector<std::unique_ptr<Metric>> train_metric_;
/*! \brief Metrics for validation data */
......@@ -446,21 +439,11 @@ DllExport int LGBM_DatasetGetNumFeature(DatesetHandle handle,
// ---- start of booster
DllExport int LGBM_BoosterCreate(const DatesetHandle train_data,
const DatesetHandle valid_datas[],
int n_valid_datas,
const char* parameters,
const char* init_model_filename,
BoosterHandle* out) {
API_BEGIN();
const Dataset* p_train_data = reinterpret_cast<const Dataset*>(train_data);
std::vector<const Dataset*> p_valid_datas;
for (int i = 0; i < n_valid_datas; ++i) {
p_valid_datas.emplace_back(reinterpret_cast<const Dataset*>(valid_datas[i]));
}
auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, p_valid_datas, parameters));
if (init_model_filename != nullptr) {
ret->LoadModelFromFile(init_model_filename);
}
auto ret = std::unique_ptr<Booster>(new Booster(p_train_data, parameters));
*out = ret.release();
API_END();
}
......@@ -482,6 +465,25 @@ DllExport int LGBM_BoosterFree(BoosterHandle handle) {
API_END();
}
/*!
* \brief C API: add a validation dataset to an existing booster.
*        Handles are opaque pointers produced by the matching *Create calls.
*/
DllExport int LGBM_BoosterAddValidData(BoosterHandle handle,
  const DatesetHandle valid_data) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(valid_data);
  ref_booster->AddValidData(p_dataset);
  API_END();
}
/*!
* \brief C API: replace the training dataset of an existing booster.
*        Handles are opaque pointers produced by the matching *Create calls.
*/
DllExport int LGBM_BoosterResetTrainingData(BoosterHandle handle,
  const DatesetHandle train_data) {
  API_BEGIN();
  Booster* ref_booster = reinterpret_cast<Booster*>(handle);
  const Dataset* p_dataset = reinterpret_cast<const Dataset*>(train_data);
  ref_booster->ResetTrainingData(p_dataset);
  API_END();
}
DllExport int LGBM_BoosterResetParameter(BoosterHandle handle, const char* parameters) {
API_BEGIN();
Booster* ref_booster = reinterpret_cast<Booster*>(handle);
......
......@@ -174,10 +174,10 @@ def test_dataset():
test_free_dataset(train)
def test_booster():
train = test_load_from_mat('../../examples/binary_classification/binary.train', None)
test = [test_load_from_mat('../../examples/binary_classification/binary.test', train)]
test = test_load_from_mat('../../examples/binary_classification/binary.test', train)
booster = ctypes.c_void_p()
LIB.LGBM_BoosterCreate(train, c_array(ctypes.c_void_p, test),
len(test), c_str("app=binary metric=auc num_leaves=31 verbose=0"),None, ctypes.byref(booster))
LIB.LGBM_BoosterCreate(train, c_str("app=binary metric=auc num_leaves=31 verbose=0"), ctypes.byref(booster))
LIB.LGBM_BoosterAddValidData(booster, test)
is_finished = ctypes.c_int(0)
for i in range(100):
LIB.LGBM_BoosterUpdateOneIter(booster,ctypes.byref(is_finished))
......@@ -188,7 +188,7 @@ def test_booster():
LIB.LGBM_BoosterSaveModel(booster, -1, c_str('model.txt'))
LIB.LGBM_BoosterFree(booster)
test_free_dataset(train)
test_free_dataset(test[0])
test_free_dataset(test)
booster2 = ctypes.c_void_p()
num_total_model = ctypes.c_long()
LIB.LGBM_BoosterCreateFromModelfile(c_str('model.txt'), ctypes.byref(num_total_model), ctypes.byref(booster2))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment