// multiclass_objective.hpp: multiclass objective functions (softmax and one-vs-all) for LightGBM.
#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_

#include <LightGBM/objective_function.h>

#include <cmath>
#include <cstring>

#include <memory>
#include <sstream>
#include <string>
#include <vector>

#include "binary_objective.hpp"

namespace LightGBM {
/*!
Guolin Ke's avatar
Guolin Ke committed
14
* \brief Objective function for multiclass classification, use softmax as objective functions
15
*/
Guolin Ke's avatar
Guolin Ke committed
16
class MulticlassSoftmax: public ObjectiveFunction {
17
public:
Guolin Ke's avatar
Guolin Ke committed
18
  explicit MulticlassSoftmax(const ObjectiveConfig& config) {
19
    num_class_ = config.num_class;
Guolin Ke's avatar
Guolin Ke committed
20
    softmax_weight_decay_ = 1e-3;
21
  }
22

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
  explicit MulticlassSoftmax(const std::vector<std::string>& strs) {
    num_class_ = -1;
    for (auto str : strs) {
      auto tokens = Common::Split(str.c_str(), ":");
      if (tokens.size() == 2) {
        if (tokens[0] == std::string("num_class")) {
          Common::Atoi(tokens[1].c_str(), &num_class_);
        }
      }
    }
    if (num_class_ < 0) {
      Log::Fatal("Objective should contains num_class field");
    }
  }

Guolin Ke's avatar
Guolin Ke committed
38
39
  ~MulticlassSoftmax() {

40
  }
41

42
43
44
45
  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    label_ = metadata.label();
    weights_ = metadata.weights();
Guolin Ke's avatar
Guolin Ke committed
46
    label_int_.resize(num_data_);
47
    std::vector<data_size_t> cnt_per_class(num_class_, 0);
48
49
50
51
52
    for (int i = 0; i < num_data_; ++i) {
      label_int_[i] = static_cast<int>(label_[i]);
      if (label_int_[i] < 0 || label_int_[i] >= num_class_) {
        Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
      }
53
      ++cnt_per_class[label_int_[i]];
54
    }
55
56
57
58
59
60
61
62
63
64
    int non_empty_class = 0;
    is_empty_class_ = std::vector<bool>(num_class_, false);
    for (int i = 0; i < num_class_; ++i) {
      if (cnt_per_class[i] > 0) {
        ++non_empty_class;
      } else {
        is_empty_class_[i] = true;
      }
    }
    if (non_empty_class < 2) { non_empty_class = 2; }
Guolin Ke's avatar
Guolin Ke committed
65
    hessian_nor_ = static_cast<float>(non_empty_class) / (non_empty_class - 1);
66
67
  }

Guolin Ke's avatar
Guolin Ke committed
68
  void GetGradients(const double* score, float* gradients, float* hessians) const override {
69
    if (weights_ == nullptr) {
zhangyafeikimi's avatar
zhangyafeikimi committed
70
71
      std::vector<double> rec;
      #pragma omp parallel for schedule(static) private(rec)
72
      for (data_size_t i = 0; i < num_data_; ++i) {
zhangyafeikimi's avatar
zhangyafeikimi committed
73
        rec.resize(num_class_);
Guolin Ke's avatar
Guolin Ke committed
74
        for (int k = 0; k < num_class_; ++k) {
75
76
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          rec[k] = static_cast<double>(score[idx]);
77
        }
78
        Common::Softmax(&rec);
79
        for (int k = 0; k < num_class_; ++k) {
80
          if (is_empty_class_[k]) { continue; }
81
          auto p = rec[k];
82
          size_t idx = static_cast<size_t>(num_data_) * k + i;
83
          if (label_int_[i] == k) {
Guolin Ke's avatar
Guolin Ke committed
84
            gradients[idx] = static_cast<float>(p - 1.0f + softmax_weight_decay_ * score[idx]);
85
          } else {
Guolin Ke's avatar
Guolin Ke committed
86
            gradients[idx] = static_cast<float>(p + softmax_weight_decay_ * score[idx]);
87
          }
Guolin Ke's avatar
Guolin Ke committed
88
          hessians[idx] = static_cast<float>(hessian_nor_ * p * (1.0f - p) + softmax_weight_decay_);
89
        }
90
91
      }
    } else {
zhangyafeikimi's avatar
zhangyafeikimi committed
92
93
      std::vector<double> rec;
      #pragma omp parallel for schedule(static) private(rec)
94
      for (data_size_t i = 0; i < num_data_; ++i) {
zhangyafeikimi's avatar
zhangyafeikimi committed
95
        rec.resize(num_class_);
Guolin Ke's avatar
Guolin Ke committed
96
        for (int k = 0; k < num_class_; ++k) {
97
98
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          rec[k] = static_cast<double>(score[idx]);
99
        }
100
101
        Common::Softmax(&rec);
        for (int k = 0; k < num_class_; ++k) {
102
          if (is_empty_class_[k]) { continue; }
103
          auto p = rec[k];
104
          size_t idx = static_cast<size_t>(num_data_) * k + i;
105
          if (label_int_[i] == k) {
Guolin Ke's avatar
Guolin Ke committed
106
            gradients[idx] = static_cast<float>((p - 1.0f + softmax_weight_decay_ * score[idx]) * weights_[i]);
107
          } else {
Guolin Ke's avatar
Guolin Ke committed
108
            gradients[idx] = static_cast<float>((p + softmax_weight_decay_ * score[idx]) * weights_[i]);
109
          }
Guolin Ke's avatar
Guolin Ke committed
110
          hessians[idx] = static_cast<float>((hessian_nor_ * p * (1.0f - p) + softmax_weight_decay_)* weights_[i]);
111
112
113
114
115
        }
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
116
117
  void ConvertOutput(const double* input, double* output) const override {
    Common::Softmax(input, output, num_class_);
118
119
  }

Guolin Ke's avatar
Guolin Ke committed
120
121
  const char* GetName() const override {
    return "multiclass";
122
123
  }

124
125
126
127
128
129
130
131
132
  std::string ToString() const override {
    std::stringstream str_buf;
    str_buf << GetName() << " ";
    str_buf << "num_class:" << num_class_;
    return str_buf.str();
  }

  bool SkipEmptyClass() const override { return true; }

Guolin Ke's avatar
Guolin Ke committed
133
134
135
  int NumTreePerIteration() const override { return num_class_; }

  int NumPredictOneRow() const override { return num_class_; }
136

137
138
139
140
141
142
143
144
private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief Pointer of label */
  const float* label_;
  /*! \brief Corresponding integers of label_ */
Guolin Ke's avatar
Guolin Ke committed
145
  std::vector<int> label_int_;
146
147
  /*! \brief Weights for data */
  const float* weights_;
148
  std::vector<bool> is_empty_class_;
Guolin Ke's avatar
Guolin Ke committed
149
  double softmax_weight_decay_;
Guolin Ke's avatar
Guolin Ke committed
150
  float hessian_nor_;
Guolin Ke's avatar
Guolin Ke committed
151
152
153
154
155
156
157
158
159
160
161
162
163
};

/*!
* \brief Objective function for multiclass classification, use one-vs-all binary objective function
*/
class MulticlassOVA: public ObjectiveFunction {
public:
  explicit MulticlassOVA(const ObjectiveConfig& config) {
    num_class_ = config.num_class;
    for (int i = 0; i < num_class_; ++i) {
      binary_loss_.emplace_back(
        new BinaryLogloss(config, [i](float label) { return static_cast<int>(label) == i; }));
    }
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
    sigmoid_ = config.sigmoid;
  }

  explicit MulticlassOVA(const std::vector<std::string>& strs) {
    num_class_ = -1;
    sigmoid_ = -1;
    for (auto str : strs) {
      auto tokens = Common::Split(str.c_str(), ":");
      if (tokens.size() == 2) {
        if (tokens[0] == std::string("num_class")) {
          Common::Atoi(tokens[1].c_str(), &num_class_);
        } else if (tokens[0] == std::string("sigmoid")) {
          Common::Atof(tokens[1].c_str(), &sigmoid_);
        }
      }
    }
    if (num_class_ < 0) {
      Log::Fatal("Objective should contains num_class field");
    }
    if (sigmoid_ <= 0.0) {
      Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_);
    }
Guolin Ke's avatar
Guolin Ke committed
186
187
188
189
190
191
192
193
194
195
196
197
198
  }

  ~MulticlassOVA() {

  }

  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    for (int i = 0; i < num_class_; ++i) {
      binary_loss_[i]->Init(metadata, num_data);
    }
  }

Guolin Ke's avatar
Guolin Ke committed
199
  void GetGradients(const double* score, float* gradients, float* hessians) const override {
Guolin Ke's avatar
Guolin Ke committed
200
    for (int i = 0; i < num_class_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
201
      size_t bias = static_cast<size_t>(num_data_) * i;
Guolin Ke's avatar
Guolin Ke committed
202
203
204
205
206
207
208
209
      binary_loss_[i]->GetGradients(score + bias, gradients + bias, hessians + bias);
    }
  }

  const char* GetName() const override {
    return "multiclassova";
  }

Guolin Ke's avatar
Guolin Ke committed
210
  void ConvertOutput(const double* input, double* output) const override {
211
    for (int i = 0; i < num_class_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
212
      output[i] = 1.0f / (1.0f + std::exp(-sigmoid_ * input[i]));
213
214
215
216
217
218
219
220
221
222
223
224
225
    }
  }

  std::string ToString() const override {
    std::stringstream str_buf;
    str_buf << GetName() << " ";
    str_buf << "num_class:" << num_class_ << " ";
    str_buf << "sigmoid:" << sigmoid_;
    return str_buf.str();
  }

  bool SkipEmptyClass() const override { return true; }

Guolin Ke's avatar
Guolin Ke committed
226
227
228
  int NumTreePerIteration() const override { return num_class_; }

  int NumPredictOneRow() const override { return num_class_; }
229

Guolin Ke's avatar
Guolin Ke committed
230
231
232
233
234
235
private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  std::vector<std::unique_ptr<BinaryLogloss>> binary_loss_;
236
  double sigmoid_;
237
238
239
240
};

}  // namespace LightGBM
#endif   // LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_