Commit 6757a6aa authored by Guolin Ke's avatar Guolin Ke
Browse files

Limit the maximum tree output; change the Hessian computation in the multi-class objective.

parent 3beee91d
......@@ -10,6 +10,7 @@
namespace LightGBM {
#define kMaxTreeOutput (100)
/*!
* \brief Tree model
......@@ -46,8 +47,8 @@ public:
* \return The index of new leaf.
*/
int Split(int leaf, int feature, BinType bin_type, uint32_t threshold, int real_feature,
double threshold_double, double left_value,
double right_value, data_size_t left_cnt, data_size_t right_cnt, double gain);
double threshold_double, double left_value,
double right_value, data_size_t left_cnt, data_size_t right_cnt, double gain);
/*! \brief Get the output of one leaf */
inline double LeafOutput(int leaf) const { return leaf_value_[leaf]; }
......@@ -63,9 +64,9 @@ public:
* \param num_data Number of total data
* \param score Will add prediction to score
*/
void AddPredictionToScore(const Dataset* data,
data_size_t num_data,
double* score) const;
void AddPredictionToScore(const Dataset* data,
data_size_t num_data,
double* score) const;
/*!
* \brief Adding prediction value of this tree model to scorese
......@@ -79,7 +80,7 @@ public:
data_size_t num_data, double* score) const;
/*!
* \brief Prediction on one record
* \brief Prediction on one record
* \param feature_values Feature value of this record
* \return Prediction result
*/
......@@ -101,9 +102,11 @@ public:
* \param rate The factor of shrinkage
*/
inline void Shrinkage(double rate) {
#pragma omp parallel for schedule(static)
#pragma omp parallel for schedule(static)
for (int i = 0; i < num_leaves_; ++i) {
leaf_value_[i] *= rate;
if (leaf_value_[i] > kMaxTreeOutput) { leaf_value_[i] = kMaxTreeOutput; }
else if (leaf_value_[i] < -kMaxTreeOutput) { leaf_value_[i] = -kMaxTreeOutput; }
}
shrinkage_ *= rate;
}
......@@ -216,8 +219,8 @@ inline int Tree::GetLeaf(const double* feature_values) const {
int node = 0;
while (node >= 0) {
if (decision_funs[decision_type_[node]](
feature_values[split_feature_[node]],
threshold_[node])) {
feature_values[split_feature_[node]],
threshold_[node])) {
node = left_child_[node];
} else {
node = right_child_[node];
......
......@@ -37,10 +37,6 @@ public:
}
}
Log::Info("Number of positive: %d, number of negative: %d", cnt_positive, cnt_negative);
// cannot continue if all sample are same class
if (cnt_positive == 0 || cnt_negative == 0) {
Log::Fatal("Training data only contains one class");
}
// use -1 for negative class, and 1 for positive class
label_val_[0] = -1;
label_val_[1] = 1;
......@@ -48,7 +44,7 @@ public:
label_weights_[0] = 1.0f;
label_weights_[1] = 1.0f;
// if using unbalance, change the labels weight
if (is_unbalance_) {
if (is_unbalance_ && cnt_positive > 0 && cnt_negative > 0) {
if (cnt_positive > cnt_negative) {
label_weights_[1] = 1.0f;
label_weights_[0] = static_cast<double>(cnt_positive) / cnt_negative;
......
......@@ -61,10 +61,10 @@ public:
size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) {
gradients[idx] = static_cast<score_t>(p - 1.0f) * label_pos_weights_[k];
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p))* label_pos_weights_[k];
hessians[idx] = static_cast<score_t>(p * (1.0f - p))* label_pos_weights_[k];
} else {
gradients[idx] = static_cast<score_t>(p);
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p));
hessians[idx] = static_cast<score_t>(p * (1.0f - p));
}
}
}
......@@ -82,10 +82,10 @@ public:
size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) {
gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]) * label_pos_weights_[k];
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p) * weights_[i]) * label_pos_weights_[k];
hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]) * label_pos_weights_[k];
} else {
gradients[idx] = static_cast<score_t>(p * weights_[i]);
hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p) * weights_[i]);
hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]);
}
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment