#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_

#include <LightGBM/objective_function.h>

#include <cstring>
#include <cmath>
#include <memory>
#include <vector>

#include "binary_objective.hpp"

namespace LightGBM {
/*!
* \brief Objective function for multiclass classification, using softmax as the objective function
*/
class MulticlassSoftmax: public ObjectiveFunction {
public:
  explicit MulticlassSoftmax(const ObjectiveConfig& config) {
    num_class_ = config.num_class;
  }

  ~MulticlassSoftmax() {

  }

  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    label_ = metadata.label();
    weights_ = metadata.weights();
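    // Cache the labels as integers once and validate that each one lies in [0, num_class_).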
    label_int_.resize(num_data_);
    #pragma omp parallel for schedule(static)
    for (data_size_t i = 0; i < num_data_; ++i) {
      label_int_[i] = static_cast<int>(label_[i]);
      if (label_int_[i] < 0 || label_int_[i] >= num_class_) {
        Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
      }
    }
  }

  void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
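    // Scores, gradients and hessians are laid out class-major: the entry for class k of data
    // point i lives at index static_cast<size_t>(num_data_) * k + i. For every data point the
    // raw scores are converted to probabilities p_k with a softmax; the gradient w.r.t. the raw
    // score of class k is p_k - 1 for the true class and p_k otherwise, and the hessian is taken
    // as 2 * p_k * (1 - p_k), twice the diagonal of the exact softmax Hessian.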
    if (weights_ == nullptr) {
      std::vector<double> rec;
      #pragma omp parallel for schedule(static) private(rec)
      for (data_size_t i = 0; i < num_data_; ++i) {
        rec.resize(num_class_);
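        // Gather this data point's raw scores across all classes (class-major layout).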
        for (int k = 0; k < num_class_; ++k) {
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          rec[k] = static_cast<double>(score[idx]);
        }
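        // Normalize the raw scores into class probabilities in place.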
        Common::Softmax(&rec);
        for (int k = 0; k < num_class_; ++k) {
          auto p = rec[k];
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          if (label_int_[i] == k) {
            gradients[idx] = static_cast<score_t>(p - 1.0f);
          } else {
            gradients[idx] = static_cast<score_t>(p);
          }
          hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p));
        }
      }
    } else {
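      // Weighted case: same gradients and hessians as above, scaled by the per-sample weight.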
      std::vector<double> rec;
      #pragma omp parallel for schedule(static) private(rec)
      for (data_size_t i = 0; i < num_data_; ++i) {
        rec.resize(num_class_);
        for (int k = 0; k < num_class_; ++k) {
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          rec[k] = static_cast<double>(score[idx]);
        }
        Common::Softmax(&rec);
        for (int k = 0; k < num_class_; ++k) {
          auto p = rec[k];
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          if (label_int_[i] == k) {
            gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]);
          } else {
            gradients[idx] = static_cast<score_t>(p * weights_[i]);
          }
          hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p) * weights_[i]);
        }
      }
    }
  }

  const char* GetName() const override {
    return "multiclass";
  }

private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief Pointer of label */
  const float* label_;
  /*! \brief Corresponding integers of label_ */
  std::vector<int> label_int_;
  /*! \brief Weights for data */
  const float* weights_;
};

/*!
* \brief Objective function for multiclass classification, using one-vs-all binary objective functions
*/
class MulticlassOVA: public ObjectiveFunction {
public:
  explicit MulticlassOVA(const ObjectiveConfig& config) {
    num_class_ = config.num_class;
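    // One binary logloss objective per class; the captured functor treats samples whose
    // integer label equals i as the positive class of objective i.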
    for (int i = 0; i < num_class_; ++i) {
      binary_loss_.emplace_back(
        new BinaryLogloss(config, [i](float label) { return static_cast<int>(label) == i; }));
    }
  }

  ~MulticlassOVA() {

  }

  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    for (int i = 0; i < num_class_; ++i) {
      binary_loss_[i]->Init(metadata, num_data);
    }
  }

  void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
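    // Delegate to the per-class binary objectives: class i operates on the contiguous slice
    // [i * num_data_, (i + 1) * num_data_) of the class-major score, gradient and hessian buffers.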
    for (int i = 0; i < num_class_; ++i) {
      int64_t bias = static_cast<int64_t>(num_data_) * i;
      binary_loss_[i]->GetGradients(score + bias, gradients + bias, hessians + bias);
    }
  }

  const char* GetName() const override {
    return "multiclassova";
  }

private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  std::vector<std::unique_ptr<BinaryLogloss>> binary_loss_;
};

}  // namespace LightGBM
#endif   // LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_