// multiclass_objective.hpp
#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_

#include <LightGBM/objective_function.h>

#include <cstring>
#include <cmath>

namespace LightGBM {
/*!
* \brief Objective function for multiclass classification
*/
class MulticlassLogloss: public ObjectiveFunction {
public:
Guolin Ke's avatar
Guolin Ke committed
15
  explicit MulticlassLogloss(const ObjectiveConfig& config) {
16
    num_class_ = config.num_class;
17
    is_unbalance_ = config.is_unbalance;
18
  }
19

20
21
  ~MulticlassLogloss() {
  }
22

23
24
25
26
  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    label_ = metadata.label();
    weights_ = metadata.weights();
Guolin Ke's avatar
Guolin Ke committed
27
    label_int_.resize(num_data_);
28
29
30
31
32
33
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < num_data_; ++i) {
      label_int_[i] = static_cast<int>(label_[i]);
      if (label_int_[i] < 0 || label_int_[i] >= num_class_) {
        Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
      }
34
    }
35
36
37
38
39
40
41
42
43
44
45
46
    label_pos_weights_ = std::vector<float>(num_class_, 1);
    if (is_unbalance_) {
      std::vector<int> cnts(num_class_, 0);
      for (int i = 0; i < num_data_; ++i) {
        ++cnts[label_int_[i]];
      }
      for (int i = 0; i < num_class_; ++i) {
        int cnt_cur = cnts[i];
        int cnt_other = (num_data_ - cnts[i]);
        label_pos_weights_[i] = static_cast<float>(cnt_other) / cnt_cur;
      }
    } 
47
48
  }

49
  void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
50
51
52
    if (weights_ == nullptr) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data_; ++i) {
53
        std::vector<double> rec(num_class_);
54
        for (int k = 0; k < num_class_; ++k){
55
56
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          rec[k] = static_cast<double>(score[idx]);
57
        }
58
        Common::Softmax(&rec);
59
        for (int k = 0; k < num_class_; ++k) {
60
          auto p = rec[k];
61
          size_t idx = static_cast<size_t>(num_data_) * k + i;
62
          if (label_int_[i] == k) {
63
64
            gradients[idx] = static_cast<score_t>(p - 1.0f) * label_pos_weights_[k];
            hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p))* label_pos_weights_[k];
65
          } else {
66
            gradients[idx] = static_cast<score_t>(p);
67
            hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p));
68
          }
69
        }
70
71
72
73
      }
    } else {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data_; ++i) {
74
        std::vector<double> rec(num_class_);
75
        for (int k = 0; k < num_class_; ++k){
76
77
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          rec[k] = static_cast<double>(score[idx]);
78
        }
79
80
        Common::Softmax(&rec);
        for (int k = 0; k < num_class_; ++k) {
81
          auto p = rec[k];
82
          size_t idx = static_cast<size_t>(num_data_) * k + i;
83
          if (label_int_[i] == k) {
84
85
            gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]) * label_pos_weights_[k];
            hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p) * weights_[i]) * label_pos_weights_[k];
86
          } else {
87
            gradients[idx] = static_cast<score_t>(p * weights_[i]);
88
            hessians[idx] = static_cast<score_t>(2.0f * p * (1.0f - p) * weights_[i]);
89
          }
90
          
91
92
93
94
95
        }
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
96
97
  const char* GetName() const override {
    return "multiclass";
98
99
100
101
102
103
104
105
106
107
  }

private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief Pointer of label */
  const float* label_;
  /*! \brief Corresponding integers of label_ */
Guolin Ke's avatar
Guolin Ke committed
108
  std::vector<int> label_int_;
109
110
  /*! \brief Weights for data */
  const float* weights_;
111
112
113
  /*! \brief Weights for label */
  std::vector<float> label_pos_weights_;
  bool is_unbalance_;
114
115
116
117
};

}  // namespace LightGBM
#endif   // LightGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_