/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#ifndef LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_

#include <LightGBM/network.h>
#include <LightGBM/objective_function.h>

#include <algorithm>
#include <cmath>
#include <cstring>
#include <functional>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace LightGBM {
/*!
19
* \brief Objective function for binary classification
Guolin Ke's avatar
Guolin Ke committed
20
21
*/
class BinaryLogloss: public ObjectiveFunction {
Nikita Titov's avatar
Nikita Titov committed
22
 public:
Guolin Ke's avatar
Guolin Ke committed
23
24
25
  explicit BinaryLogloss(const Config& config,
                         std::function<bool(label_t)> is_pos = nullptr)
      : deterministic_(config.deterministic) {
26
    sigmoid_ = static_cast<double>(config.sigmoid);
Guolin Ke's avatar
Guolin Ke committed
27
    if (sigmoid_ <= 0.0) {
28
      Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_);
Guolin Ke's avatar
Guolin Ke committed
29
    }
30
    is_unbalance_ = config.is_unbalance;
31
    scale_pos_weight_ = static_cast<double>(config.scale_pos_weight);
32
    if (is_unbalance_ && std::fabs(scale_pos_weight_ - 1.0f) > 1e-6) {
33
      Log::Fatal("Cannot set is_unbalance and scale_pos_weight at the same time");
34
    }
Guolin Ke's avatar
Guolin Ke committed
35
36
    is_pos_ = is_pos;
    if (is_pos_ == nullptr) {
37
      is_pos_ = [](label_t label) { return label > 0; };
Guolin Ke's avatar
Guolin Ke committed
38
    }
Guolin Ke's avatar
Guolin Ke committed
39
  }
Guolin Ke's avatar
Guolin Ke committed
40

Guolin Ke's avatar
Guolin Ke committed
41
42
  explicit BinaryLogloss(const std::vector<std::string>& strs)
      : deterministic_(false) {
43
44
    sigmoid_ = -1;
    for (auto str : strs) {
Guolin Ke's avatar
Guolin Ke committed
45
      auto tokens = Common::Split(str.c_str(), ':');
46
47
48
49
50
51
52
53
54
55
56
      if (tokens.size() == 2) {
        if (tokens[0] == std::string("sigmoid")) {
          Common::Atof(tokens[1].c_str(), &sigmoid_);
        }
      }
    }
    if (sigmoid_ <= 0.0) {
      Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_);
    }
  }

Guolin Ke's avatar
Guolin Ke committed
57
  ~BinaryLogloss() {}
Guolin Ke's avatar
Guolin Ke committed
58

Guolin Ke's avatar
Guolin Ke committed
59
60
61
62
63
64
65
  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    label_ = metadata.label();
    weights_ = metadata.weights();
    data_size_t cnt_positive = 0;
    data_size_t cnt_negative = 0;
    // count for positive and negative samples
66
    #pragma omp parallel for schedule(static) reduction(+:cnt_positive, cnt_negative)
Guolin Ke's avatar
Guolin Ke committed
67
    for (data_size_t i = 0; i < num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
68
      if (is_pos_(label_[i])) {
Guolin Ke's avatar
Guolin Ke committed
69
70
71
72
73
        ++cnt_positive;
      } else {
        ++cnt_negative;
      }
    }
74
75
76
77
78
    num_pos_data_ = cnt_positive;
    if (Network::num_machines() > 1) {
      cnt_positive = Network::GlobalSyncUpBySum(cnt_positive);
      cnt_negative = Network::GlobalSyncUpBySum(cnt_negative);
    }
79
    need_train_ = true;
80
    if (cnt_negative == 0 || cnt_positive == 0) {
81
      Log::Warning("Contains only one class");
82
      // not need to boost.
83
      need_train_ = false;
84
    }
ProtD's avatar
ProtD committed
85
    Log::Info("Number of positive: %d, number of negative: %d", cnt_positive, cnt_negative);
Guolin Ke's avatar
Guolin Ke committed
86
87
88
89
90
91
92
    // use -1 for negative class, and 1 for positive class
    label_val_[0] = -1;
    label_val_[1] = 1;
    // weight for label
    label_weights_[0] = 1.0f;
    label_weights_[1] = 1.0f;
    // if using unbalance, change the labels weight
93
    if (is_unbalance_ && cnt_positive > 0 && cnt_negative > 0) {
94
95
      if (cnt_positive > cnt_negative) {
        label_weights_[1] = 1.0f;
96
        label_weights_[0] = static_cast<double>(cnt_positive) / cnt_negative;
97
      } else {
98
        label_weights_[1] = static_cast<double>(cnt_negative) / cnt_positive;
99
100
        label_weights_[0] = 1.0f;
      }
Guolin Ke's avatar
Guolin Ke committed
101
    }
Guolin Ke's avatar
Guolin Ke committed
102
    label_weights_[1] *= scale_pos_weight_;
Guolin Ke's avatar
Guolin Ke committed
103
104
  }

105
  void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
106
107
108
    if (!need_train_) {
      return;
    }
Guolin Ke's avatar
Guolin Ke committed
109
110
111
112
    if (weights_ == nullptr) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data_; ++i) {
        // get label and label weights
Guolin Ke's avatar
Guolin Ke committed
113
        const int is_pos = is_pos_(label_[i]);
Guolin Ke's avatar
Guolin Ke committed
114
115
        const int label = label_val_[is_pos];
        const double label_weight = label_weights_[is_pos];
Guolin Ke's avatar
Guolin Ke committed
116
        // calculate gradients and hessians
117
        const double response = -label * sigmoid_ / (1.0f + std::exp(label * sigmoid_ * score[i]));
118
        const double abs_response = fabs(response);
119
120
        gradients[i] = static_cast<score_t>(response * label_weight);
        hessians[i] = static_cast<score_t>(abs_response * (sigmoid_ - abs_response) * label_weight);
Guolin Ke's avatar
Guolin Ke committed
121
122
123
124
125
      }
    } else {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data_; ++i) {
        // get label and label weights
Guolin Ke's avatar
Guolin Ke committed
126
        const int is_pos = is_pos_(label_[i]);
Guolin Ke's avatar
Guolin Ke committed
127
128
        const int label = label_val_[is_pos];
        const double label_weight = label_weights_[is_pos];
Guolin Ke's avatar
Guolin Ke committed
129
        // calculate gradients and hessians
130
        const double response = -label * sigmoid_ / (1.0f + std::exp(label * sigmoid_ * score[i]));
131
        const double abs_response = fabs(response);
132
133
        gradients[i] = static_cast<score_t>(response * label_weight  * weights_[i]);
        hessians[i] = static_cast<score_t>(abs_response * (sigmoid_ - abs_response) * label_weight * weights_[i]);
Guolin Ke's avatar
Guolin Ke committed
134
135
136
      }
    }
  }
137

138
  // implement custom average to boost from (if enabled among options)
139
  double BoostFromScore(int) const override {
140
141
142
    double suml = 0.0f;
    double sumw = 0.0f;
    if (weights_ != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
143
      #pragma omp parallel for schedule(static) reduction(+:suml, sumw) if (!deterministic_)
144
      for (data_size_t i = 0; i < num_data_; ++i) {
145
        suml += is_pos_(label_[i]) * weights_[i];
146
147
148
149
        sumw += weights_[i];
      }
    } else {
      sumw = static_cast<double>(num_data_);
Guolin Ke's avatar
Guolin Ke committed
150
      #pragma omp parallel for schedule(static) reduction(+:suml) if (!deterministic_)
151
      for (data_size_t i = 0; i < num_data_; ++i) {
152
        suml += is_pos_(label_[i]);
153
154
      }
    }
155
156
157
158
    if (Network::num_machines() > 1) {
      suml = Network::GlobalSyncUpBySum(suml);
      sumw = Network::GlobalSyncUpBySum(sumw);
    }
159
    double pavg = suml / sumw;
160
161
    pavg = std::min(pavg, 1.0 - kEpsilon);
    pavg = std::max<double>(pavg, kEpsilon);
162
163
164
165
    double initscore = std::log(pavg / (1.0f - pavg)) / sigmoid_;
    Log::Info("[%s:%s]: pavg=%f -> initscore=%f",  GetName(), __func__, pavg, initscore);
    return initscore;
  }
Guolin Ke's avatar
Guolin Ke committed
166

167
  bool ClassNeedTrain(int /*class_id*/) const override {
168
    return need_train_;
169
170
  }

Guolin Ke's avatar
Guolin Ke committed
171
172
  const char* GetName() const override {
    return "binary";
Guolin Ke's avatar
Guolin Ke committed
173
174
  }

Guolin Ke's avatar
Guolin Ke committed
175
176
  void ConvertOutput(const double* input, double* output) const override {
    output[0] = 1.0f / (1.0f + std::exp(-sigmoid_ * input[0]));
177
178
179
180
181
182
183
184
185
186
187
  }

  std::string ToString() const override {
    std::stringstream str_buf;
    str_buf << GetName() << " ";
    str_buf << "sigmoid:" << sigmoid_;
    return str_buf.str();
  }

  bool SkipEmptyClass() const override { return true; }

188
189
  bool NeedAccuratePrediction() const override { return false; }

Guolin Ke's avatar
Guolin Ke committed
190
191
  data_size_t NumPositiveData() const override { return num_pos_data_; }

192
 protected:
Guolin Ke's avatar
Guolin Ke committed
193
194
  /*! \brief Number of data */
  data_size_t num_data_;
Guolin Ke's avatar
Guolin Ke committed
195
196
  /*! \brief Number of positive samples */
  data_size_t num_pos_data_;
Guolin Ke's avatar
Guolin Ke committed
197
  /*! \brief Pointer of label */
198
  const label_t* label_;
Guolin Ke's avatar
Guolin Ke committed
199
200
201
  /*! \brief True if using unbalance training */
  bool is_unbalance_;
  /*! \brief Sigmoid parameter */
202
  double sigmoid_;
Guolin Ke's avatar
Guolin Ke committed
203
204
205
  /*! \brief Values for positive and negative labels */
  int label_val_[2];
  /*! \brief Weights for positive and negative labels */
206
  double label_weights_[2];
Guolin Ke's avatar
Guolin Ke committed
207
  /*! \brief Weights for data */
208
  const label_t* weights_;
209
  double scale_pos_weight_;
210
  std::function<bool(label_t)> is_pos_;
211
  bool need_train_;
Guolin Ke's avatar
Guolin Ke committed
212
  const bool deterministic_;
Guolin Ke's avatar
Guolin Ke committed
213
214
215
};

}  // namespace LightGBM
#endif   // LightGBM_OBJECTIVE_BINARY_OBJECTIVE_HPP_