"include/vscode:/vscode.git/clone" did not exist on "1c60cfa6a6dc8f24cfc2f37b0555cf1252366370"
multiclass_objective.hpp 3.92 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_

#include <LightGBM/objective_function.h>

#include <cstring>
#include <cmath>

namespace LightGBM {
/*!
* \brief Objective function for multiclass classification
*/
class MulticlassLogloss: public ObjectiveFunction {
public:
Guolin Ke's avatar
Guolin Ke committed
15
  explicit MulticlassLogloss(const ObjectiveConfig& config) {
16
    num_class_ = config.num_class;
17
    is_unbalance_ = config.is_unbalance;
18
  }
19

20
21
  ~MulticlassLogloss() {
  }
22

23
24
25
26
  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    label_ = metadata.label();
    weights_ = metadata.weights();
Guolin Ke's avatar
Guolin Ke committed
27
    label_int_.resize(num_data_);
28
29
30
31
32
33
    #pragma omp parallel for schedule(static)
    for (int i = 0; i < num_data_; ++i) {
      label_int_[i] = static_cast<int>(label_[i]);
      if (label_int_[i] < 0 || label_int_[i] >= num_class_) {
        Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
      }
34
    }
35
36
37
38
39
40
41
42
43
44
45
46
    label_pos_weights_ = std::vector<float>(num_class_, 1);
    if (is_unbalance_) {
      std::vector<int> cnts(num_class_, 0);
      for (int i = 0; i < num_data_; ++i) {
        ++cnts[label_int_[i]];
      }
      for (int i = 0; i < num_class_; ++i) {
        int cnt_cur = cnts[i];
        int cnt_other = (num_data_ - cnts[i]);
        label_pos_weights_[i] = static_cast<float>(cnt_other) / cnt_cur;
      }
    } 
47
48
  }

49
  void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
50
    if (weights_ == nullptr) {
zhangyafeikimi's avatar
zhangyafeikimi committed
51
52
      std::vector<double> rec;
      #pragma omp parallel for schedule(static) private(rec)
53
      for (data_size_t i = 0; i < num_data_; ++i) {
zhangyafeikimi's avatar
zhangyafeikimi committed
54
        rec.resize(num_class_);
55
        for (int k = 0; k < num_class_; ++k){
56
57
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          rec[k] = static_cast<double>(score[idx]);
58
        }
59
        Common::Softmax(&rec);
60
        for (int k = 0; k < num_class_; ++k) {
61
          auto p = rec[k];
62
          size_t idx = static_cast<size_t>(num_data_) * k + i;
63
          if (label_int_[i] == k) {
64
            gradients[idx] = static_cast<score_t>(p - 1.0f) * label_pos_weights_[k];
65
            hessians[idx] = static_cast<score_t>(p * (1.0f - p))* label_pos_weights_[k];
66
          } else {
67
            gradients[idx] = static_cast<score_t>(p);
68
            hessians[idx] = static_cast<score_t>(p * (1.0f - p));
69
          }
70
        }
71
72
      }
    } else {
zhangyafeikimi's avatar
zhangyafeikimi committed
73
74
      std::vector<double> rec;
      #pragma omp parallel for schedule(static) private(rec)
75
      for (data_size_t i = 0; i < num_data_; ++i) {
zhangyafeikimi's avatar
zhangyafeikimi committed
76
        rec.resize(num_class_);
77
        for (int k = 0; k < num_class_; ++k){
78
79
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          rec[k] = static_cast<double>(score[idx]);
80
        }
81
82
        Common::Softmax(&rec);
        for (int k = 0; k < num_class_; ++k) {
83
          auto p = rec[k];
84
          size_t idx = static_cast<size_t>(num_data_) * k + i;
85
          if (label_int_[i] == k) {
86
            gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]) * label_pos_weights_[k];
87
            hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]) * label_pos_weights_[k];
88
          } else {
89
            gradients[idx] = static_cast<score_t>(p * weights_[i]);
90
            hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]);
91
          }
92
          
93
94
95
96
97
        }
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
98
99
  const char* GetName() const override {
    return "multiclass";
100
101
102
103
104
105
106
107
108
109
  }

private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief Pointer of label */
  const float* label_;
  /*! \brief Corresponding integers of label_ */
Guolin Ke's avatar
Guolin Ke committed
110
  std::vector<int> label_int_;
111
112
  /*! \brief Weights for data */
  const float* weights_;
113
114
115
  /*! \brief Weights for label */
  std::vector<float> label_pos_weights_;
  bool is_unbalance_;
116
117
118
119
};

}  // namespace LightGBM
#endif   // LightGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_