// multiclass_objective.hpp (2.91 KB)
#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_

#include <LightGBM/objective_function.h>

#include <cmath>
#include <cstring>
#include <vector>

namespace LightGBM {
/*!
* \brief Objective function for multiclass classification
*/
class MulticlassLogloss: public ObjectiveFunction {
public:
  explicit MulticlassLogloss(const ObjectiveConfig& config)
        :label_int_(nullptr) {
    num_class_ = config.num_class;
  }
  
  ~MulticlassLogloss() {
    if (label_int_ != nullptr) { delete[] label_int_; }    
  }
  
  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    label_ = metadata.label();
    weights_ = metadata.weights();
    label_int_ = new int[num_data_];
    for (int i = 0; i < num_data_; ++i){
        label_int_[i] = static_cast<int>(label_[i]); 
        if (label_int_[i] < 0 || label_int_[i] >= num_class_) {
            Log::Fatal("Label must be in [0, %d), but find %d in label", num_class_, label_int_[i]);
        }
    }
  }

  void GetGradients(const score_t* score, score_t* gradients, score_t* hessians) const override {
    if (weights_ == nullptr) {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data_; ++i) {
41
        std::vector<double> rec(num_class_);
42
        for (int k = 0; k < num_class_; ++k){
43
          rec[k] = static_cast<double>(score[k * num_data_ + i]);
44
45
46
        }
        Common::Softmax(&rec);  
        for (int k = 0; k < num_class_; ++k) {
Guolin Ke's avatar
Guolin Ke committed
47
          score_t p = static_cast<score_t>(rec[k]);
48
49
50
51
52
53
54
55
56
57
58
          if (label_int_[i] == k) {
            gradients[k * num_data_ + i] = p - 1.0f;
          } else {
            gradients[k * num_data_ + i] = p;
          }
          hessians[k * num_data_ + i] = 2.0f * p * (1.0f - p);
        }  
      }
    } else {
      #pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data_; ++i) {
59
        std::vector<double> rec(num_class_);
60
        for (int k = 0; k < num_class_; ++k){
61
          rec[k] = static_cast<double>(score[k * num_data_ + i]);
62
63
64
        }  
        Common::Softmax(&rec);
        for (int k = 0; k < num_class_; ++k) {
Guolin Ke's avatar
Guolin Ke committed
65
          score_t p = static_cast<score_t>(rec[k]);
66
67
68
69
70
71
72
73
74
75
76
          if (label_int_[i] == k) {
            gradients[k * num_data_ + i] = (p - 1.0f) * weights_[i];
          } else {
            gradients[k * num_data_ + i] = p * weights_[i];
          }
          hessians[k * num_data_ + i] = 2.0f * p * (1.0f - p) * weights_[i];
        }
      }
    }
  }

77
  score_t GetSigmoid() const override {
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
    return -1.0f;
  }

private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief Pointer of label */
  const float* label_;
  /*! \brief Corresponding integers of label_ */
  int* label_int_;
  /*! \brief Weights for data */
  const float* weights_;
};

}  // namespace LightGBM
#endif   // LightGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_