#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_

#include <LightGBM/objective_function.h>

#include <cstring>
#include <cmath>
#include <vector>

#include "binary_objective.hpp"

namespace LightGBM {
/*!
* \brief Objective function for multiclass classification, use softmax as objective functions
*/
class MulticlassSoftmax: public ObjectiveFunction {
17
public:
Guolin Ke's avatar
Guolin Ke committed
18
  explicit MulticlassSoftmax(const ObjectiveConfig& config) {
19
    num_class_ = config.num_class;
Guolin Ke's avatar
Guolin Ke committed
20
    softmax_weight_decay_ = 1e-3;
21
  }
22

Guolin Ke's avatar
Guolin Ke committed
23
24
  ~MulticlassSoftmax() {

25
  }
26

27
28
29
30
  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    label_ = metadata.label();
    weights_ = metadata.weights();
Guolin Ke's avatar
Guolin Ke committed
31
    label_int_.resize(num_data_);
32
33
34
35
36
37
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < num_data_; ++i) {
      label_int_[i] = static_cast<int>(label_[i]);
      if (label_int_[i] < 0 || label_int_[i] >= num_class_) {
        Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
      }
38
    }
Guolin Ke's avatar
Guolin Ke committed
39
    hessian_nor_ = static_cast<score_t>(num_class_) / (num_class_ - 1);
40
41
  }

42
  void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
43
    if (weights_ == nullptr) {
zhangyafeikimi's avatar
zhangyafeikimi committed
44
45
      std::vector<double> rec;
      #pragma omp parallel for schedule(static) private(rec)
46
      for (data_size_t i = 0; i < num_data_; ++i) {
zhangyafeikimi's avatar
zhangyafeikimi committed
47
        rec.resize(num_class_);
Guolin Ke's avatar
Guolin Ke committed
48
        for (int k = 0; k < num_class_; ++k) {
49
50
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          rec[k] = static_cast<double>(score[idx]);
51
        }
52
        Common::Softmax(&rec);
53
        for (int k = 0; k < num_class_; ++k) {
54
          auto p = rec[k];
55
          size_t idx = static_cast<size_t>(num_data_) * k + i;
56
          if (label_int_[i] == k) {
Guolin Ke's avatar
Guolin Ke committed
57
            gradients[idx] = static_cast<score_t>(p - 1.0f + softmax_weight_decay_ * score[idx]);
58
          } else {
Guolin Ke's avatar
Guolin Ke committed
59
            gradients[idx] = static_cast<score_t>(p + softmax_weight_decay_ * score[idx]);
60
          }
Guolin Ke's avatar
Guolin Ke committed
61
          hessians[idx] = static_cast<score_t>(hessian_nor_ * p * (1.0f - p) + softmax_weight_decay_);
62
        }
63
64
      }
    } else {
zhangyafeikimi's avatar
zhangyafeikimi committed
65
66
      std::vector<double> rec;
      #pragma omp parallel for schedule(static) private(rec)
67
      for (data_size_t i = 0; i < num_data_; ++i) {
zhangyafeikimi's avatar
zhangyafeikimi committed
68
        rec.resize(num_class_);
Guolin Ke's avatar
Guolin Ke committed
69
        for (int k = 0; k < num_class_; ++k) {
70
71
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          rec[k] = static_cast<double>(score[idx]);
72
        }
73
74
        Common::Softmax(&rec);
        for (int k = 0; k < num_class_; ++k) {
75
          auto p = rec[k];
76
          size_t idx = static_cast<size_t>(num_data_) * k + i;
77
          if (label_int_[i] == k) {
Guolin Ke's avatar
Guolin Ke committed
78
            gradients[idx] = static_cast<score_t>((p - 1.0f + softmax_weight_decay_ * score[idx]) * weights_[i]);
79
          } else {
Guolin Ke's avatar
Guolin Ke committed
80
            gradients[idx] = static_cast<score_t>((p + softmax_weight_decay_ * score[idx]) * weights_[i]);
81
          }
Guolin Ke's avatar
Guolin Ke committed
82
          hessians[idx] = static_cast<score_t>((hessian_nor_ * p * (1.0f - p) + softmax_weight_decay_)* weights_[i]);
83
84
85
86
87
        }
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
88
89
  const char* GetName() const override {
    return "multiclass";
90
91
92
93
94
95
96
97
98
99
  }

private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief Pointer of label */
  const float* label_;
  /*! \brief Corresponding integers of label_ */
Guolin Ke's avatar
Guolin Ke committed
100
  std::vector<int> label_int_;
101
102
  /*! \brief Weights for data */
  const float* weights_;
Guolin Ke's avatar
Guolin Ke committed
103
104
  double softmax_weight_decay_;
  score_t hessian_nor_;
Guolin Ke's avatar
Guolin Ke committed
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
};

/*!
* \brief Objective function for multiclass classification, use one-vs-all binary objective function
*/
class MulticlassOVA: public ObjectiveFunction {
public:
  explicit MulticlassOVA(const ObjectiveConfig& config) {
    num_class_ = config.num_class;
    for (int i = 0; i < num_class_; ++i) {
      binary_loss_.emplace_back(
        new BinaryLogloss(config, [i](float label) { return static_cast<int>(label) == i; }));
    }
  }

  ~MulticlassOVA() {

  }

  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    for (int i = 0; i < num_class_; ++i) {
      binary_loss_[i]->Init(metadata, num_data);
    }
  }

  void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
    for (int i = 0; i < num_class_; ++i) {
      int64_t bias = static_cast<int64_t>(num_data_) * i;
      binary_loss_[i]->GetGradients(score + bias, gradients + bias, hessians + bias);
    }
  }

  const char* GetName() const override {
    return "multiclassova";
  }

private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  std::vector<std::unique_ptr<BinaryLogloss>> binary_loss_;
148
149
150
151
};

}  // namespace LightGBM
#endif   // LightGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_