/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_

#include <LightGBM/network.h>
#include <LightGBM/objective_function.h>

#include <string>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <memory>
#include <sstream>
#include <vector>

#include "binary_objective.hpp"

namespace LightGBM {
/*!
* \brief Objective function for multiclass classification; uses softmax as the objective function
*/
class MulticlassSoftmax: public ObjectiveFunction {
 public:
  explicit MulticlassSoftmax(const Config& config) {
    num_class_ = config.num_class;
    // This factor rescales the redundant K-class parameterization to the non-redundant form.
    // In the traditional setting of K-class classification there is one redundant class, whose
    // output is fixed to 0 (like class 0 in binary classification); see Friedman's GBDT paper.
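    // Worked example: with num_class_ = 3, factor_ = 3 / (3 - 1) = 1.5;
    // GetGradients() below multiplies the softmax hessian p * (1 - p) by this factor.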
    factor_ = static_cast<double>(num_class_) / (num_class_ - 1.0f);
  }

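  // Reconstructs the objective from the tokens of a serialized model string
  // (e.g. a "num_class:3" token, matching the format written by ToString() below).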
  explicit MulticlassSoftmax(const std::vector<std::string>& strs) {
    num_class_ = -1;
    for (const auto& str : strs) {
      auto tokens = Common::Split(str.c_str(), ':');
      if (tokens.size() == 2) {
        if (tokens[0] == std::string("num_class")) {
          Common::Atoi(tokens[1].c_str(), &num_class_);
        }
      }
    }
    if (num_class_ < 0) {
      Log::Fatal("Objective should contain num_class field");
    }
    factor_ = static_cast<double>(num_class_) / (num_class_ - 1.0f);
  }

  ~MulticlassSoftmax() {
  }

  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    label_ = metadata.label();
    weights_ = metadata.weights();
    label_int_.resize(num_data_);
    class_init_probs_.resize(num_class_, 0.0);
    double sum_weight = 0.0;
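    // Accumulate per-class (optionally weighted) label counts; they are
    // normalized below into the prior probabilities used by BoostFromScore().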
    for (int i = 0; i < num_data_; ++i) {
      label_int_[i] = static_cast<int>(label_[i]);
      if (label_int_[i] < 0 || label_int_[i] >= num_class_) {
        Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
      }
      if (weights_ == nullptr) {
        class_init_probs_[label_int_[i]] += 1.0;
      } else {
        class_init_probs_[label_int_[i]] += weights_[i];
        sum_weight += weights_[i];
      }
    }
    if (weights_ == nullptr) {
      sum_weight = num_data_;
    }
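    // In distributed training, aggregate the total weight and the per-class
    // counts across all machines before normalizing.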
    if (Network::num_machines() > 1) {
      sum_weight = Network::GlobalSyncUpBySum(sum_weight);
      for (int i = 0; i < num_class_; ++i) {
        class_init_probs_[i] = Network::GlobalSyncUpBySum(class_init_probs_[i]);
      }
    }
    for (int i = 0; i < num_class_; ++i) {
      class_init_probs_[i] /= sum_weight;
    }
  }

  void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
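    // Layout note: scores, gradients and hessians are class-major, i.e. the
    // entry for (row i, class k) lives at index static_cast<size_t>(num_data_) * k + i.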
    if (weights_ == nullptr) {
      std::vector<double> rec;
      #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) private(rec)
      for (data_size_t i = 0; i < num_data_; ++i) {
        rec.resize(num_class_);
        for (int k = 0; k < num_class_; ++k) {
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          rec[k] = static_cast<double>(score[idx]);
        }
        Common::Softmax(&rec);
        for (int k = 0; k < num_class_; ++k) {
          auto p = rec[k];
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          if (label_int_[i] == k) {
            gradients[idx] = static_cast<score_t>(p - 1.0f);
          } else {
            gradients[idx] = static_cast<score_t>(p);
          }
          hessians[idx] = static_cast<score_t>(factor_ * p * (1.0f - p));
        }
      }
    } else {
      std::vector<double> rec;
      #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) private(rec)
      for (data_size_t i = 0; i < num_data_; ++i) {
        rec.resize(num_class_);
        for (int k = 0; k < num_class_; ++k) {
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          rec[k] = static_cast<double>(score[idx]);
        }
        Common::Softmax(&rec);
        for (int k = 0; k < num_class_; ++k) {
          auto p = rec[k];
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          if (label_int_[i] == k) {
            gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]);
          } else {
            gradients[idx] = static_cast<score_t>(p * weights_[i]);
          }
          hessians[idx] = static_cast<score_t>((factor_ * p * (1.0f - p)) * weights_[i]);
        }
      }
    }
  }

  void ConvertOutput(const double* input, double* output) const override {
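    // Softmax turns the raw scores into class probabilities; for example, the
    // input {0.0, 1.0, -1.0} maps to approximately {0.245, 0.665, 0.090}.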
    Common::Softmax(input, output, num_class_);
  }

  const char* GetName() const override {
    return "multiclass";
  }

  std::string ToString() const override {
    std::stringstream str_buf;
    str_buf << GetName() << " ";
    str_buf << "num_class:" << num_class_;
    return str_buf.str();
  }

  bool SkipEmptyClass() const override { return true; }

  int NumModelPerIteration() const override { return num_class_; }

  int NumPredictOneRow() const override { return num_class_; }

  bool NeedAccuratePrediction() const override { return false; }

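  // Starts each class at the log of its prior probability (clamped at kEpsilon);
  // since softmax(log p_k) = p_k, the initial prediction reproduces the observed
  // class frequencies.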
  double BoostFromScore(int class_id) const override {
    return std::log(std::max<double>(kEpsilon, class_init_probs_[class_id]));
  }

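  // A class whose prior probability is numerically 0 or 1 is constant in the
  // training data, so no trees need to be trained for it.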
  bool ClassNeedTrain(int class_id) const override {
    if (std::fabs(class_init_probs_[class_id]) <= kEpsilon
        || std::fabs(class_init_probs_[class_id]) >= 1.0 - kEpsilon) {
      return false;
    } else {
      return true;
    }
  }

 protected:
  /*! \brief Hessian rescaling factor, num_class_ / (num_class_ - 1) */
  double factor_;
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief Pointer of label */
  const label_t* label_;
  /*! \brief Corresponding integers of label_ */
  std::vector<int> label_int_;
  /*! \brief Weights for data */
  const label_t* weights_;
  /*! \brief Prior probability of each class, from (weighted) label counts */
  std::vector<double> class_init_probs_;
};

/*!
* \brief Objective function for multiclass classification; uses one binary (one-vs-all) objective function per class
*/
class MulticlassOVA: public ObjectiveFunction {
 public:
  explicit MulticlassOVA(const Config& config) {
    num_class_ = config.num_class;
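    // One BinaryLogloss per class: class i's binary problem treats label == i
    // as positive and all other labels as negative.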
    for (int i = 0; i < num_class_; ++i) {
      binary_loss_.emplace_back(
        new BinaryLogloss(config, [i](label_t label) { return static_cast<int>(label) == i; }));
    }
    sigmoid_ = config.sigmoid;
  }

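  // Reconstructs the objective from the tokens of a serialized model string
  // (e.g. "num_class:3" and "sigmoid:1" tokens, matching ToString() below).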
  explicit MulticlassOVA(const std::vector<std::string>& strs) {
    num_class_ = -1;
    sigmoid_ = -1;
    for (const auto& str : strs) {
      auto tokens = Common::Split(str.c_str(), ':');
      if (tokens.size() == 2) {
        if (tokens[0] == std::string("num_class")) {
          Common::Atoi(tokens[1].c_str(), &num_class_);
        } else if (tokens[0] == std::string("sigmoid")) {
          Common::Atof(tokens[1].c_str(), &sigmoid_);
        }
      }
    }
    if (num_class_ < 0) {
      Log::Fatal("Objective should contain num_class field");
    }
    if (sigmoid_ <= 0.0) {
      Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_);
    }
  }

  ~MulticlassOVA() {
  }

  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    for (int i = 0; i < num_class_; ++i) {
      binary_loss_[i]->Init(metadata, num_data);
    }
  }

  void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
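    // Each class's binary problem owns a contiguous block of num_data_ entries
    // in score/gradients/hessians, starting at offset num_data_ * i.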
    for (int i = 0; i < num_class_; ++i) {
      int64_t offset = static_cast<int64_t>(num_data_) * i;
      binary_loss_[i]->GetGradients(score + offset, gradients + offset, hessians + offset);
    }
  }

  const char* GetName() const override {
    return "multiclassova";
  }

  void ConvertOutput(const double* input, double* output) const override {
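    // Per-class probability via a scaled sigmoid, 1 / (1 + exp(-sigmoid_ * x));
    // unlike softmax, these outputs need not sum to 1 across classes.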
    for (int i = 0; i < num_class_; ++i) {
      output[i] = 1.0f / (1.0f + std::exp(-sigmoid_ * input[i]));
    }
  }

  std::string ToString() const override {
    std::stringstream str_buf;
    str_buf << GetName() << " ";
    str_buf << "num_class:" << num_class_ << " ";
    str_buf << "sigmoid:" << sigmoid_;
    return str_buf.str();
  }

  bool SkipEmptyClass() const override { return true; }

  int NumModelPerIteration() const override { return num_class_; }

  int NumPredictOneRow() const override { return num_class_; }

  bool NeedAccuratePrediction() const override { return false; }

  double BoostFromScore(int class_id) const override {
    return binary_loss_[class_id]->BoostFromScore(0);
  }

  bool ClassNeedTrain(int class_id) const override {
    return binary_loss_[class_id]->ClassNeedTrain(0);
  }

 protected:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief One binary (one-vs-all) objective per class */
  std::vector<std::unique_ptr<BinaryLogloss>> binary_loss_;
  /*! \brief Sigmoid scaling parameter, used in ConvertOutput */
  double sigmoid_;
};

}  // namespace LightGBM
#endif  // LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_