Commit 6757a6aa authored by Guolin Ke's avatar Guolin Ke
Browse files

Limit the max tree output. Change the hessian in the multi-class objective.

parent 3beee91d
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
namespace LightGBM { namespace LightGBM {
#define kMaxTreeOutput (100)
/*! /*!
* \brief Tree model * \brief Tree model
...@@ -46,8 +47,8 @@ public: ...@@ -46,8 +47,8 @@ public:
* \return The index of new leaf. * \return The index of new leaf.
*/ */
int Split(int leaf, int feature, BinType bin_type, uint32_t threshold, int real_feature, int Split(int leaf, int feature, BinType bin_type, uint32_t threshold, int real_feature,
double threshold_double, double left_value, double threshold_double, double left_value,
double right_value, data_size_t left_cnt, data_size_t right_cnt, double gain); double right_value, data_size_t left_cnt, data_size_t right_cnt, double gain);
/*! \brief Get the output of one leaf */ /*! \brief Get the output of one leaf */
inline double LeafOutput(int leaf) const { return leaf_value_[leaf]; } inline double LeafOutput(int leaf) const { return leaf_value_[leaf]; }
...@@ -63,9 +64,9 @@ public: ...@@ -63,9 +64,9 @@ public:
* \param num_data Number of total data * \param num_data Number of total data
* \param score Will add prediction to score * \param score Will add prediction to score
*/ */
void AddPredictionToScore(const Dataset* data, void AddPredictionToScore(const Dataset* data,
data_size_t num_data, data_size_t num_data,
double* score) const; double* score) const;
/*! /*!
* \brief Adding prediction value of this tree model to scores * \brief Adding prediction value of this tree model to scores
...@@ -79,7 +80,7 @@ public: ...@@ -79,7 +80,7 @@ public:
data_size_t num_data, double* score) const; data_size_t num_data, double* score) const;
/*! /*!
* \brief Prediction on one record * \brief Prediction on one record
* \param feature_values Feature value of this record * \param feature_values Feature value of this record
* \return Prediction result * \return Prediction result
*/ */
...@@ -101,9 +102,11 @@ public: ...@@ -101,9 +102,11 @@ public:
* \param rate The factor of shrinkage * \param rate The factor of shrinkage
*/ */
inline void Shrinkage(double rate) { inline void Shrinkage(double rate) {
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int i = 0; i < num_leaves_; ++i) { for (int i = 0; i < num_leaves_; ++i) {
leaf_value_[i] *= rate; leaf_value_[i] *= rate;
if (leaf_value_[i] > kMaxTreeOutput) { leaf_value_[i] = kMaxTreeOutput; }
else if (leaf_value_[i] < -kMaxTreeOutput) { leaf_value_[i] = -kMaxTreeOutput; }
} }
shrinkage_ *= rate; shrinkage_ *= rate;
} }
...@@ -216,8 +219,8 @@ inline int Tree::GetLeaf(const double* feature_values) const { ...@@ -216,8 +219,8 @@ inline int Tree::GetLeaf(const double* feature_values) const {
int node = 0; int node = 0;
while (node >= 0) { while (node >= 0) {
if (decision_funs[decision_type_[node]]( if (decision_funs[decision_type_[node]](
feature_values[split_feature_[node]], feature_values[split_feature_[node]],
threshold_[node])) { threshold_[node])) {
node = left_child_[node]; node = left_child_[node];
} else { } else {
node = right_child_[node]; node = right_child_[node];
......
...@@ -37,10 +37,6 @@ public: ...@@ -37,10 +37,6 @@ public:
} }
} }
Log::Info("Number of positive: %d, number of negative: %d", cnt_positive, cnt_negative); Log::Info("Number of positive: %d, number of negative: %d", cnt_positive, cnt_negative);
// cannot continue if all sample are same class
if (cnt_positive == 0 || cnt_negative == 0) {
Log::Fatal("Training data only contains one class");
}
// use -1 for negative class, and 1 for positive class // use -1 for negative class, and 1 for positive class
label_val_[0] = -1; label_val_[0] = -1;
label_val_[1] = 1; label_val_[1] = 1;
...@@ -48,7 +44,7 @@ public: ...@@ -48,7 +44,7 @@ public:
label_weights_[0] = 1.0f; label_weights_[0] = 1.0f;
label_weights_[1] = 1.0f; label_weights_[1] = 1.0f;
// if using unbalance, change the labels weight // if using unbalance, change the labels weight
if (is_unbalance_) { if (is_unbalance_ && cnt_positive > 0 && cnt_negative > 0) {
if (cnt_positive > cnt_negative) { if (cnt_positive > cnt_negative) {
label_weights_[1] = 1.0f; label_weights_[1] = 1.0f;
label_weights_[0] = static_cast<double>(cnt_positive) / cnt_negative; label_weights_[0] = static_cast<double>(cnt_positive) / cnt_negative;
......
...@@ -61,10 +61,10 @@ public: ...@@ -61,10 +61,10 @@ public:
size_t idx = static_cast<size_t>(num_data_) * k + i; size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) { if (label_int_[i] == k) {
gradients[idx] = static_cast<score_t>(p - 1.0f) * label_pos_weights_[k]; gradients[idx] = static_cast<score_t>(p - 1.0f) * label_pos_weights_[k];
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p))* label_pos_weights_[k]; hessians[idx] = static_cast<score_t>(p * (1.0f - p))* label_pos_weights_[k];
} else { } else {
gradients[idx] = static_cast<score_t>(p); gradients[idx] = static_cast<score_t>(p);
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p)); hessians[idx] = static_cast<score_t>(p * (1.0f - p));
} }
} }
} }
...@@ -82,10 +82,10 @@ public: ...@@ -82,10 +82,10 @@ public:
size_t idx = static_cast<size_t>(num_data_) * k + i; size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) { if (label_int_[i] == k) {
gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]) * label_pos_weights_[k]; gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]) * label_pos_weights_[k];
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p) * weights_[i]) * label_pos_weights_[k]; hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]) * label_pos_weights_[k];
} else { } else {
gradients[idx] = static_cast<score_t>(p * weights_[i]); gradients[idx] = static_cast<score_t>(p * weights_[i]);
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p) * weights_[i]); hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]);
} }
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment