Commit 96cba416, authored by Guolin Ke, committed by GitHub
Browse files

Adapt learning rate to DART and xgboost dart mode (#139)

parent 1cacaef9
...@@ -73,9 +73,16 @@ The parameter format is ```key1=value1 key2=value2 ... ``` . And parameters can ...@@ -73,9 +73,16 @@ The parameter format is ```key1=value1 key2=value2 ... ``` . And parameters can
* l2 regularization * l2 regularization
* ```min_gain_to_split``` , default=```0```, type=double * ```min_gain_to_split``` , default=```0```, type=double
* The minimal gain to perform split * The minimal gain to perform split
* ```drop_rate```, default=```0.01```, type=double * ```drop_rate```, default=```0.1```, type=double
* only used in ```dart```, will drop ```drop_rate*current_num_models``` before boosting. * only used in ```dart```
* If you want to use ```skip_rate``` like in xgboost, you can use [callbacks](Python-API.md#callbacks) with changing ```drop_rate```. * ```skip_drop```, default=```0.5```, type=double
* only used in ```dart```, probability of skipping the drop step in an iteration
* ```max_drop```, default=```50```, type=int
* only used in ```dart```, maximum number of trees dropped in one iteration.
* ```uniform_drop```, default=```false```, type=bool
* only used in ```dart```, set to true to use uniform drop
* ```xgboost_dart_mode```, default=```false```, type=bool
* only used in ```dart```, set to true to use xgboost dart mode
* ```drop_seed```, default=```4```, type=int * ```drop_seed```, default=```4```, type=int
* only used in ```dart```, random seed used to choose which models to drop. * only used in ```dart```, random seed used to choose which models to drop.
......
...@@ -205,7 +205,11 @@ public: ...@@ -205,7 +205,11 @@ public:
int bagging_freq = 0; int bagging_freq = 0;
int early_stopping_round = 0; int early_stopping_round = 0;
int num_class = 1; int num_class = 1;
double drop_rate = 0.01; double drop_rate = 0.1;
int max_drop = 50;
double skip_drop = 0.5;
bool xgboost_dart_mode = false;
bool uniform_drop = false;
int drop_seed = 4; int drop_seed = 4;
TreeLearnerType tree_learner_type = TreeLearnerType::kSerialTreeLearner; TreeLearnerType tree_learner_type = TreeLearnerType::kSerialTreeLearner;
TreeConfig tree_config; TreeConfig tree_config;
......
...@@ -114,10 +114,10 @@ def param_dict_to_str(data): ...@@ -114,10 +114,10 @@ def param_dict_to_str(data):
return "" return ""
pairs = [] pairs = []
for key, val in data.items(): for key, val in data.items():
if is_str(val) or isinstance(val, (int, float, bool, np.integer, np.float, np.float32)): if isinstance(val, (list, tuple, set)) or is_numpy_1d_array(val):
pairs.append(str(key)+'='+str(val))
elif isinstance(val, (list, tuple, set)):
pairs.append(str(key)+'='+','.join(map(str, val))) pairs.append(str(key)+'='+','.join(map(str, val)))
elif is_str(val) or isinstance(val, (int, float, bool)) or is_numpy_object(val):
pairs.append(str(key)+'='+str(val))
else: else:
raise TypeError('Unknown type of parameter:%s, got:%s' raise TypeError('Unknown type of parameter:%s, got:%s'
% (key, type(val).__name__)) % (key, type(val).__name__))
......
...@@ -19,7 +19,7 @@ public: ...@@ -19,7 +19,7 @@ public:
/*! /*!
* \brief Constructor * \brief Constructor
*/ */
DART(): GBDT() { } DART() : GBDT() { }
/*! /*!
* \brief Destructor * \brief Destructor
*/ */
...@@ -36,6 +36,7 @@ public: ...@@ -36,6 +36,7 @@ public:
const std::vector<const Metric*>& training_metrics) override { const std::vector<const Metric*>& training_metrics) override {
GBDT::Init(config, train_data, object_function, training_metrics); GBDT::Init(config, train_data, object_function, training_metrics);
random_for_drop_ = Random(gbdt_config_->drop_seed); random_for_drop_ = Random(gbdt_config_->drop_seed);
sum_weight_ = 0.0f;
} }
/*! /*!
* \brief one training iteration * \brief one training iteration
...@@ -45,6 +46,10 @@ public: ...@@ -45,6 +46,10 @@ public:
GBDT::TrainOneIter(gradient, hessian, false); GBDT::TrainOneIter(gradient, hessian, false);
// normalize // normalize
Normalize(); Normalize();
if (!gbdt_config_->uniform_drop) {
tree_weight_.push_back(shrinkage_rate_);
sum_weight_ += shrinkage_rate_;
}
if (is_eval) { if (is_eval) {
return EvalAndCheckEarlyStopping(); return EvalAndCheckEarlyStopping();
} else { } else {
...@@ -55,7 +60,6 @@ public: ...@@ -55,7 +60,6 @@ public:
void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function, void ResetTrainingData(const BoostingConfig* config, const Dataset* train_data, const ObjectiveFunction* object_function,
const std::vector<const Metric*>& training_metrics) { const std::vector<const Metric*>& training_metrics) {
GBDT::ResetTrainingData(config, train_data, object_function, training_metrics); GBDT::ResetTrainingData(config, train_data, object_function, training_metrics);
shrinkage_rate_ = gbdt_config_->learning_rate / (gbdt_config_->learning_rate + static_cast<double>(drop_index_.size()));
} }
/*! /*!
...@@ -84,19 +88,31 @@ private: ...@@ -84,19 +88,31 @@ private:
*/ */
void DroppingTrees() { void DroppingTrees() {
drop_index_.clear(); drop_index_.clear();
// select dropping tree indexes based on drop_rate bool is_skip = random_for_drop_.NextDouble() < gbdt_config_->skip_drop;
// if drop rate is too small, skip this step, drop one tree randomly // select dropping tree indexes based on drop_rate and tree weights
if (gbdt_config_->drop_rate > kEpsilon) { if (!is_skip) {
for (int i = 0; i < iter_; ++i) { double drop_rate = gbdt_config_->drop_rate;
if (random_for_drop_.NextDouble() < gbdt_config_->drop_rate) { if (!gbdt_config_->uniform_drop) {
drop_index_.push_back(i); double inv_average_weight = static_cast<double>(tree_weight_.size()) / sum_weight_;
if (gbdt_config_->max_drop > 0) {
drop_rate = std::min(drop_rate, gbdt_config_->max_drop * inv_average_weight / sum_weight_);
}
for (int i = 0; i < iter_; ++i) {
if (random_for_drop_.NextDouble() < drop_rate * tree_weight_[i] * inv_average_weight) {
drop_index_.push_back(i);
}
}
} else {
if (gbdt_config_->max_drop > 0) {
drop_rate = std::min(drop_rate, gbdt_config_->max_drop / static_cast<double>(iter_));
}
for (int i = 0; i < iter_; ++i) {
if (random_for_drop_.NextDouble() < drop_rate) {
drop_index_.push_back(i);
}
} }
} }
} }
// binomial-plus-one, at least one tree will be dropped
if (drop_index_.empty()) {
drop_index_ = random_for_drop_.Sample(iter_, 1);
}
// drop trees // drop trees
for (auto i : drop_index_) { for (auto i : drop_index_) {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) { for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
...@@ -105,34 +121,70 @@ private: ...@@ -105,34 +121,70 @@ private:
train_score_updater_->AddScore(models_[curr_tree].get(), curr_class); train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
} }
} }
shrinkage_rate_ = gbdt_config_->learning_rate / (gbdt_config_->learning_rate + static_cast<double>(drop_index_.size())); if (!gbdt_config_->xgboost_dart_mode) {
shrinkage_rate_ = gbdt_config_->learning_rate / (1.0f + static_cast<double>(drop_index_.size()));
} else {
if (drop_index_.empty()) {
shrinkage_rate_ = gbdt_config_->learning_rate;
} else {
shrinkage_rate_ = gbdt_config_->learning_rate / (gbdt_config_->learning_rate + static_cast<double>(drop_index_.size()));
}
}
} }
/*! /*!
* \brief normalize dropped trees * \brief normalize dropped trees
* NOTE: num_drop_tree(k), learning_rate(lr), shrinkage_rate_ = lr / (k + lr) * NOTE: num_drop_tree(k), learning_rate(lr), shrinkage_rate_ = lr / (k + 1)
* step 1: shrink tree to -1 -> drop tree * step 1: shrink tree to -1 -> drop tree
* step 2: shrink tree to k / (k + lr) - 1 from -1 * step 2: shrink tree to k / (k + 1) - 1 from -1, by 1/(k+1)
* -> normalize for valid data * -> normalize for valid data
* step 3: shrink tree to k / (k + lr) from k / (k + lr) - 1 * step 3: shrink tree to k / (k + 1) from k / (k + 1) - 1, by -k
* -> normalize for train data * -> normalize for train data
* end with tree weight = k / (k + lr) * end with tree weight = (k / (k + 1)) * old_weight
*/ */
void Normalize() { void Normalize() {
double k = static_cast<double>(drop_index_.size()); double k = static_cast<double>(drop_index_.size());
for (auto i : drop_index_) { if (!gbdt_config_->xgboost_dart_mode) {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) { for (auto i : drop_index_) {
auto curr_tree = i * num_class_ + curr_class; for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
// update validation score auto curr_tree = i * num_class_ + curr_class;
models_[curr_tree]->Shrinkage(shrinkage_rate_); // update validation score
for (auto& score_updater : valid_score_updater_) { models_[curr_tree]->Shrinkage(1.0f / (k + 1.0f));
score_updater->AddScore(models_[curr_tree].get(), curr_class); for (auto& score_updater : valid_score_updater_) {
score_updater->AddScore(models_[curr_tree].get(), curr_class);
}
// update training score
models_[curr_tree]->Shrinkage(-k);
train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
}
if (!gbdt_config_->uniform_drop) {
sum_weight_ -= tree_weight_[i] * (1.0f / (k + 1.0f));
tree_weight_[i] *= (k / (k + 1.0f));
}
}
} else {
for (auto i : drop_index_) {
for (int curr_class = 0; curr_class < num_class_; ++curr_class) {
auto curr_tree = i * num_class_ + curr_class;
// update validation score
models_[curr_tree]->Shrinkage(shrinkage_rate_);
for (auto& score_updater : valid_score_updater_) {
score_updater->AddScore(models_[curr_tree].get(), curr_class);
}
// update training score
models_[curr_tree]->Shrinkage(-k / gbdt_config_->learning_rate);
train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
}
if (!gbdt_config_->uniform_drop) {
sum_weight_ -= tree_weight_[i] * (1.0f / (k + gbdt_config_->learning_rate));;
tree_weight_[i] *= (k / (k + gbdt_config_->learning_rate));
} }
// update training score
models_[curr_tree]->Shrinkage(-k / gbdt_config_->learning_rate);
train_score_updater_->AddScore(models_[curr_tree].get(), curr_class);
} }
} }
} }
/*! \brief The weights of all trees, used to choose drop trees */
std::vector<double> tree_weight_;
/*! \brief sum weights of all trees */
double sum_weight_;
/*! \brief The indexes of dropping trees */ /*! \brief The indexes of dropping trees */
std::vector<int> drop_index_; std::vector<int> drop_index_;
/*! \brief Random generator, used to select dropping trees */ /*! \brief Random generator, used to select dropping trees */
......
...@@ -313,7 +313,12 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par ...@@ -313,7 +313,12 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
GetInt(params, "num_class", &num_class); GetInt(params, "num_class", &num_class);
GetInt(params, "drop_seed", &drop_seed); GetInt(params, "drop_seed", &drop_seed);
GetDouble(params, "drop_rate", &drop_rate); GetDouble(params, "drop_rate", &drop_rate);
GetDouble(params, "skip_drop", &skip_drop);
GetInt(params, "max_drop", &max_drop);
GetBool(params, "xgboost_dart_mode", &xgboost_dart_mode);
GetBool(params, "uniform_drop", &uniform_drop);
CHECK(drop_rate <= 1.0 && drop_rate >= 0.0); CHECK(drop_rate <= 1.0 && drop_rate >= 0.0);
CHECK(skip_drop <= 1.0 && skip_drop >= 0.0);
GetTreeLearnerType(params); GetTreeLearnerType(params);
tree_config.Set(params); tree_config.Set(params);
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment