Commit 6757a6aa authored by Guolin Ke's avatar Guolin Ke
Browse files

Limit the maximum tree output; change the Hessian in the multi-class objective.

parent 3beee91d
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
namespace LightGBM { namespace LightGBM {
#define kMaxTreeOutput (100)
/*! /*!
* \brief Tree model * \brief Tree model
...@@ -101,9 +102,11 @@ public: ...@@ -101,9 +102,11 @@ public:
* \param rate The factor of shrinkage * \param rate The factor of shrinkage
*/ */
inline void Shrinkage(double rate) { inline void Shrinkage(double rate) {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < num_leaves_; ++i) { for (int i = 0; i < num_leaves_; ++i) {
leaf_value_[i] *= rate; leaf_value_[i] *= rate;
if (leaf_value_[i] > kMaxTreeOutput) { leaf_value_[i] = kMaxTreeOutput; }
else if (leaf_value_[i] < -kMaxTreeOutput) { leaf_value_[i] = -kMaxTreeOutput; }
} }
shrinkage_ *= rate; shrinkage_ *= rate;
} }
......
...@@ -37,10 +37,6 @@ public: ...@@ -37,10 +37,6 @@ public:
} }
} }
Log::Info("Number of positive: %d, number of negative: %d", cnt_positive, cnt_negative); Log::Info("Number of positive: %d, number of negative: %d", cnt_positive, cnt_negative);
// cannot continue if all sample are same class
if (cnt_positive == 0 || cnt_negative == 0) {
Log::Fatal("Training data only contains one class");
}
// use -1 for negative class, and 1 for positive class // use -1 for negative class, and 1 for positive class
label_val_[0] = -1; label_val_[0] = -1;
label_val_[1] = 1; label_val_[1] = 1;
...@@ -48,7 +44,7 @@ public: ...@@ -48,7 +44,7 @@ public:
label_weights_[0] = 1.0f; label_weights_[0] = 1.0f;
label_weights_[1] = 1.0f; label_weights_[1] = 1.0f;
// if using unbalance, change the labels weight // if using unbalance, change the labels weight
if (is_unbalance_) { if (is_unbalance_ && cnt_positive > 0 && cnt_negative > 0) {
if (cnt_positive > cnt_negative) { if (cnt_positive > cnt_negative) {
label_weights_[1] = 1.0f; label_weights_[1] = 1.0f;
label_weights_[0] = static_cast<double>(cnt_positive) / cnt_negative; label_weights_[0] = static_cast<double>(cnt_positive) / cnt_negative;
......
...@@ -61,10 +61,10 @@ public: ...@@ -61,10 +61,10 @@ public:
size_t idx = static_cast<size_t>(num_data_) * k + i; size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) { if (label_int_[i] == k) {
gradients[idx] = static_cast<score_t>(p - 1.0f) * label_pos_weights_[k]; gradients[idx] = static_cast<score_t>(p - 1.0f) * label_pos_weights_[k];
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p))* label_pos_weights_[k]; hessians[idx] = static_cast<score_t>(p * (1.0f - p))* label_pos_weights_[k];
} else { } else {
gradients[idx] = static_cast<score_t>(p); gradients[idx] = static_cast<score_t>(p);
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p)); hessians[idx] = static_cast<score_t>(p * (1.0f - p));
} }
} }
} }
...@@ -82,10 +82,10 @@ public: ...@@ -82,10 +82,10 @@ public:
size_t idx = static_cast<size_t>(num_data_) * k + i; size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) { if (label_int_[i] == k) {
gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]) * label_pos_weights_[k]; gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]) * label_pos_weights_[k];
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p) * weights_[i]) * label_pos_weights_[k]; hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]) * label_pos_weights_[k];
} else { } else {
gradients[idx] = static_cast<score_t>(p * weights_[i]); gradients[idx] = static_cast<score_t>(p * weights_[i]);
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p) * weights_[i]); hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment