// multiclass_objective.hpp: softmax multiclass log-loss objective for LightGBM.
#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_

#include <LightGBM/objective_function.h>

#include <cmath>
#include <cstring>
#include <vector>
namespace LightGBM {
/*!
* \brief Objective function for multiclass classification
*/
class MulticlassLogloss: public ObjectiveFunction {
public:
Guolin Ke's avatar
Guolin Ke committed
15
  explicit MulticlassLogloss(const ObjectiveConfig& config) {
16
17
    num_class_ = config.num_class;
  }
18

19
20
  ~MulticlassLogloss() {
  }
21

22
23
24
25
  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    label_ = metadata.label();
    weights_ = metadata.weights();
Guolin Ke's avatar
Guolin Ke committed
26
    label_int_.resize(num_data_);
27
    for (int i = 0; i < num_data_; ++i){
28
        label_int_[i] = static_cast<int>(label_[i]);
29
        if (label_int_[i] < 0 || label_int_[i] >= num_class_) {
30
            Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
31
32
33
34
35
36
37
38
        }
    }
  }

  void GetGradients(const score_t* score, score_t* gradients, score_t* hessians) const override {
    if (weights_ == nullptr) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data_; ++i) {
39
        std::vector<double> rec(num_class_);
40
        for (int k = 0; k < num_class_; ++k){
41
42
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          rec[k] = static_cast<double>(score[idx]);
43
        }
44
        Common::Softmax(&rec);
45
        for (int k = 0; k < num_class_; ++k) {
Guolin Ke's avatar
Guolin Ke committed
46
          score_t p = static_cast<score_t>(rec[k]);
47
          size_t idx = static_cast<size_t>(num_data_) * k + i;
48
          if (label_int_[i] == k) {
49
            gradients[idx] = p - 1.0f;
50
          } else {
51
            gradients[idx] = p;
52
          }
53
          hessians[idx] = 2.0f * p * (1.0f - p);
54
        }
55
56
57
58
      }
    } else {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data_; ++i) {
59
        std::vector<double> rec(num_class_);
60
        for (int k = 0; k < num_class_; ++k){
61
62
          size_t idx = static_cast<size_t>(num_data_) * k + i;
          rec[k] = static_cast<double>(score[idx]);
63
        }
64
65
        Common::Softmax(&rec);
        for (int k = 0; k < num_class_; ++k) {
Guolin Ke's avatar
Guolin Ke committed
66
          score_t p = static_cast<score_t>(rec[k]);
67
          size_t idx = static_cast<size_t>(num_data_) * k + i;
68
          if (label_int_[i] == k) {
69
            gradients[idx] = (p - 1.0f) * weights_[i];
70
          } else {
71
            gradients[idx] = p * weights_[i];
72
          }
73
          hessians[idx] = 2.0f * p * (1.0f - p) * weights_[i];
74
75
76
77
78
        }
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
79
80
  const char* GetName() const override {
    return "multiclass";
81
82
83
84
85
86
87
88
89
90
  }

private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief Pointer of label */
  const float* label_;
  /*! \brief Corresponding integers of label_ */
Guolin Ke's avatar
Guolin Ke committed
91
  std::vector<int> label_int_;
92
93
94
95
96
97
  /*! \brief Weights for data */
  const float* weights_;
};

}  // namespace LightGBM
#endif   // LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_