// multiclass_objective.hpp
#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_

#include <LightGBM/objective_function.h>

#include <cstring>
#include <cmath>

namespace LightGBM {
/*!
* \brief Objective function for multiclass classification
*/
class MulticlassLogloss: public ObjectiveFunction {
public:
Guolin Ke's avatar
Guolin Ke committed
15
  explicit MulticlassLogloss(const ObjectiveConfig& config) {
16
17
    num_class_ = config.num_class;
  }
18

19
20
  ~MulticlassLogloss() {
  }
21

22
23
24
25
  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    label_ = metadata.label();
    weights_ = metadata.weights();
Guolin Ke's avatar
Guolin Ke committed
26
    label_int_.resize(num_data_);
27
    for (int i = 0; i < num_data_; ++i){
28
        label_int_[i] = static_cast<int>(label_[i]);
29
        if (label_int_[i] < 0 || label_int_[i] >= num_class_) {
30
            Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
31
32
33
34
35
36
37
38
        }
    }
  }

  void GetGradients(const score_t* score, score_t* gradients, score_t* hessians) const override {
    if (weights_ == nullptr) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data_; ++i) {
39
        std::vector<double> rec(num_class_);
40
        for (int k = 0; k < num_class_; ++k){
41
          rec[k] = static_cast<double>(score[k * num_data_ + i]);
42
        }
43
        Common::Softmax(&rec);
44
        for (int k = 0; k < num_class_; ++k) {
Guolin Ke's avatar
Guolin Ke committed
45
          score_t p = static_cast<score_t>(rec[k]);
46
47
48
49
50
51
          if (label_int_[i] == k) {
            gradients[k * num_data_ + i] = p - 1.0f;
          } else {
            gradients[k * num_data_ + i] = p;
          }
          hessians[k * num_data_ + i] = 2.0f * p * (1.0f - p);
52
        }
53
54
55
56
      }
    } else {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data_; ++i) {
57
        std::vector<double> rec(num_class_);
58
        for (int k = 0; k < num_class_; ++k){
59
          rec[k] = static_cast<double>(score[k * num_data_ + i]);
60
        }
61
62
        Common::Softmax(&rec);
        for (int k = 0; k < num_class_; ++k) {
Guolin Ke's avatar
Guolin Ke committed
63
          score_t p = static_cast<score_t>(rec[k]);
64
65
66
67
68
69
70
71
72
73
74
          if (label_int_[i] == k) {
            gradients[k * num_data_ + i] = (p - 1.0f) * weights_[i];
          } else {
            gradients[k * num_data_ + i] = p * weights_[i];
          }
          hessians[k * num_data_ + i] = 2.0f * p * (1.0f - p) * weights_[i];
        }
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
75
76
  const char* GetName() const override {
    return "multiclass";
77
78
79
80
81
82
83
84
85
86
  }

private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief Pointer of label */
  const float* label_;
  /*! \brief Corresponding integers of label_ */
Guolin Ke's avatar
Guolin Ke committed
87
  std::vector<int> label_int_;
88
89
90
91
92
93
  /*! \brief Weights for data */
  const float* weights_;
};

}  // namespace LightGBM
#endif   // LightGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_