// binary_objective.hpp — last committed by Guolin Ke.
#ifndef LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_

#include <LightGBM/objective_function.h>

#include <cstring>
#include <cmath>

namespace LightGBM {
/*!
* \brief Objective function for binary classification (log loss).
*
* Each sample is mapped to class 0 (negative) or 1 (positive) by ClassIndex,
* then to a signed label y = -1 / +1. With s = sigmoid parameter and f = score,
* the gradient and hessian produced by GetGradients are the first and second
* derivatives w.r.t. f of  log(1 + exp(-2 * y * s * f)),  multiplied by a
* per-class weight (and by the per-sample weight when one is provided).
*/
class BinaryLogloss: public ObjectiveFunction {
public:
  /*!
  * \brief Constructor
  * \param config Objective config; this class reads is_unbalance and sigmoid
  */
  explicit BinaryLogloss(const ObjectiveConfig& config) {
    is_unbalance_ = config.is_unbalance;
    sigmoid_ = static_cast<score_t>(config.sigmoid);
    if (sigmoid_ <= 0.0) {
      Log::Stderr("sigmoid param %f should greater than zero", sigmoid_);
    }
  }
  ~BinaryLogloss() {}

  /*!
  * \brief Bind training labels/weights and compute the per-class weights.
  * \param metadata Source of labels and (optional, possibly nullptr) sample weights
  * \param num_data Number of training samples
  */
  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    label_ = metadata.label();
    weights_ = metadata.weights();
    data_size_t cnt_positive = 0;
    data_size_t cnt_negative = 0;
    // count positive and negative samples with the same rule GetGradients uses,
    // so the class sizes agree with how each sample is actually trained
    for (data_size_t i = 0; i < num_data_; ++i) {
      if (ClassIndex(label_[i]) == 1) {
        ++cnt_positive;
      } else {
        ++cnt_negative;
      }
    }
    Log::Stdout("number of postive:%d number of negative:%d", cnt_positive, cnt_negative);
    // cannot continue if all samples are of the same class
    if (cnt_positive == 0 || cnt_negative == 0) {
      Log::Stderr("input training data only contain one class");
    }
    // use -1 for negative class, and 1 for positive class
    label_val_[0] = -1;
    label_val_[1] = 1;
    // weight for label
    label_weights_[0] = 1.0f;
    label_weights_[1] = 1.0f;
    // if using unbalance, weight each class by the inverse of its size;
    // an empty class keeps weight 1 instead of dividing by zero
    if (is_unbalance_) {
      if (cnt_positive > 0) {
        label_weights_[1] = 1.0f / cnt_positive;
      }
      if (cnt_negative > 0) {
        label_weights_[0] = 1.0f / cnt_negative;
      }
    }
  }

  /*!
  * \brief Fill gradients/hessians for every training sample.
  * \param score Current model scores, one per sample
  * \param gradients Output, one per sample
  * \param hessians Output, one per sample
  */
  void GetGradients(const score_t* score, score_t* gradients, score_t* hessians) const override {
    if (weights_ == nullptr) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data_; ++i) {
        ComputeGradient(score[i], label_[i], &gradients[i], &hessians[i]);
      }
    } else {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data_; ++i) {
        ComputeGradient(score[i], label_[i], &gradients[i], &hessians[i]);
        // apply the per-sample weight on top of the class weight
        gradients[i] *= weights_[i];
        hessians[i] *= weights_[i];
      }
    }
  }

  double GetSigmoid() const override {
    return sigmoid_;
  }

private:
  /*!
  * \brief Map a raw label to a class index: 1 = positive, 0 = negative.
  *
  * Labels are expected to be 0 or 1. For every label in (-1, 2) this equals
  * static_cast<int>(label); for any other value (including NaN) it still
  * returns a valid index, so label_val_/label_weights_ are never read out of bounds.
  */
  static int ClassIndex(float label) {
    return label >= 1.0f ? 1 : 0;
  }

  /*!
  * \brief Gradient and hessian of one sample's loss, scaled by its class weight
  *        (but not by any per-sample weight).
  * \param score Model score of the sample
  * \param raw_label Label of the sample as stored in metadata
  * \param gradient Output gradient
  * \param hessian Output hessian
  */
  void ComputeGradient(score_t score, float raw_label,
                       score_t* gradient, score_t* hessian) const {
    const int cls = ClassIndex(raw_label);
    const int label = label_val_[cls];
    const score_t label_weight = label_weights_[cls];
    // d/df of log(1 + exp(-2 * label * sigmoid * f))
    const score_t response = -2.0f * label * sigmoid_ / (1.0f + std::exp(2.0f * label * sigmoid_ * score));
    const score_t abs_response = std::fabs(response);
    *gradient = response * label_weight;
    // second derivative, expressed through |response| to reuse the exp above
    *hessian = abs_response * (2.0f * sigmoid_ - abs_response) * label_weight;
  }

  /*! \brief Number of data */
  data_size_t num_data_ = 0;
  /*! \brief Pointer of label */
  const float* label_ = nullptr;
  /*! \brief True if using unbalance training */
  bool is_unbalance_ = false;
  /*! \brief Sigmoid parameter */
  score_t sigmoid_;
  /*! \brief Signed label values indexed by class (0 = negative, 1 = positive) */
  int label_val_[2] = {-1, 1};
  /*! \brief Weights for negative (index 0) and positive (index 1) labels */
  score_t label_weights_[2] = {1.0f, 1.0f};
  /*! \brief Per-sample weights; nullptr means all samples are unweighted */
  const float* weights_ = nullptr;
};

}  // namespace LightGBM
#endif  // LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_