Commit 841a8987 authored by Guolin Ke's avatar Guolin Ke
Browse files

support OVA multi-classification.

parent 14195876
...@@ -138,7 +138,8 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin ...@@ -138,7 +138,8 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin
void OverallConfig::CheckParamConflict() { void OverallConfig::CheckParamConflict() {
// check if objective_type, metric_type, and num_class match // check if objective_type, metric_type, and num_class match
bool objective_type_multiclass = (objective_type == std::string("multiclass")); bool objective_type_multiclass = (objective_type == std::string("multiclass")
|| objective_type == std::string("multiclassova"));
int num_class_check = boosting_config.num_class; int num_class_check = boosting_config.num_class;
if (objective_type_multiclass) { if (objective_type_multiclass) {
if (num_class_check <= 1) { if (num_class_check <= 1) {
...@@ -151,11 +152,19 @@ void OverallConfig::CheckParamConflict() { ...@@ -151,11 +152,19 @@ void OverallConfig::CheckParamConflict() {
} }
if (boosting_config.is_provide_training_metric || !io_config.valid_data_filenames.empty()) { if (boosting_config.is_provide_training_metric || !io_config.valid_data_filenames.empty()) {
for (std::string metric_type : metric_types) { for (std::string metric_type : metric_types) {
bool metric_type_multiclass = (metric_type == std::string("multi_logloss") || metric_type == std::string("multi_error")); bool metric_type_multiclass = (metric_type == std::string("multi_logloss")
|| metric_type == std::string("multi_error")
|| metric_type == std::string("multi_loglossova"));
if ((objective_type_multiclass && !metric_type_multiclass) if ((objective_type_multiclass && !metric_type_multiclass)
|| (!objective_type_multiclass && metric_type_multiclass)) { || (!objective_type_multiclass && metric_type_multiclass)) {
Log::Fatal("Objective and metrics don't match"); Log::Fatal("Objective and metrics don't match");
} }
if (objective_type == std::string("multiclassova") && metric_type == std::string("multi_logloss")) {
Log::Fatal("Wrong metric. For Multi-class with OVA, you should use multi_loglossova metric.");
}
if (objective_type == std::string("multiclass") && metric_type == std::string("multi_loglossova")) {
Log::Fatal("Wrong metric. For Multi-class with softmax, you should use multi_logloss metric.");
}
} }
} }
......
...@@ -29,7 +29,9 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config ...@@ -29,7 +29,9 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config
} else if (type == std::string("map")) { } else if (type == std::string("map")) {
return new MapMetric(config); return new MapMetric(config);
} else if (type == std::string("multi_logloss")) { } else if (type == std::string("multi_logloss")) {
return new MultiLoglossMetric(config); return new MultiSoftmaxLoglossMetric(config);
} else if (type == std::string("multi_loglossova")) {
return new MultiOVALoglossMetric(config);
} else if (type == std::string("multi_error")) { } else if (type == std::string("multi_error")) {
return new MultiErrorMetric(config); return new MultiErrorMetric(config);
} }
......
...@@ -79,8 +79,6 @@ public: ...@@ -79,8 +79,6 @@ public:
} }
private: private:
/*! \brief Output frequency */
int output_freq_;
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Number of classes */ /*! \brief Number of classes */
...@@ -116,9 +114,9 @@ public: ...@@ -116,9 +114,9 @@ public:
}; };
/*! \brief Logloss for multiclass task */ /*! \brief Logloss for multiclass task */
class MultiLoglossMetric: public MulticlassMetric<MultiLoglossMetric> { class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetric> {
public: public:
explicit MultiLoglossMetric(const MetricConfig& config) :MulticlassMetric<MultiLoglossMetric>(config) {} explicit MultiSoftmaxLoglossMetric(const MetricConfig& config) :MulticlassMetric<MultiSoftmaxLoglossMetric>(config) {}
inline static double LossOnPoint(float label, std::vector<double>& score) { inline static double LossOnPoint(float label, std::vector<double>& score) {
size_t k = static_cast<size_t>(label); size_t k = static_cast<size_t>(label);
...@@ -135,5 +133,84 @@ public: ...@@ -135,5 +133,84 @@ public:
} }
}; };
/*!
 * \brief Logloss metric for multiclass task trained with One-vs-All binary objectives.
 *        Scores are laid out class-major: score[num_data * k + i] is the raw score of
 *        data point i for class k. Only the true-class score contributes to the loss
 *        (per-sample binary logloss through a sigmoid), matching the OVA objective.
 */
class MultiOVALoglossMetric: public Metric {
public:
  explicit MultiOVALoglossMetric(const MetricConfig& config) {
    num_class_ = config.num_class;
    sigmoid_ = config.sigmoid;
  }
  virtual ~MultiOVALoglossMetric() {
  }
  void Init(const Metadata& metadata, data_size_t num_data) override {
    name_.emplace_back("multi_loglossova");
    num_data_ = num_data;
    // get label
    label_ = metadata.label();
    // get weights (nullptr means all weights are 1)
    weights_ = metadata.weights();
    if (weights_ == nullptr) {
      sum_weights_ = static_cast<double>(num_data_);
    } else {
      sum_weights_ = 0.0;
      for (data_size_t i = 0; i < num_data_; ++i) {
        sum_weights_ += weights_[i];
      }
    }
  }
  const std::vector<std::string>& GetName() const override {
    return name_;
  }
  double factor_to_bigger_better() const override {
    // logloss: smaller is better
    return -1.0;
  }
  std::vector<double> Eval(const double* score) const override {
    double sum_loss = 0.0;
    if (weights_ == nullptr) {
      #pragma omp parallel for schedule(static) reduction(+:sum_loss)
      for (data_size_t i = 0; i < num_data_; ++i) {
        // index of the true-class score for data point i (class-major layout)
        size_t idx = static_cast<size_t>(num_data_) * static_cast<int>(label_[i]) + i;
        double prob = 1.0 / (1.0 + std::exp(-sigmoid_ * score[idx]));
        // clip to avoid log(0)
        if (prob < kEpsilon) { prob = kEpsilon; }
        // add loss
        sum_loss += -std::log(prob);
      }
    } else {
      #pragma omp parallel for schedule(static) reduction(+:sum_loss)
      for (data_size_t i = 0; i < num_data_; ++i) {
        size_t idx = static_cast<size_t>(num_data_) * static_cast<int>(label_[i]) + i;
        double prob = 1.0 / (1.0 + std::exp(-sigmoid_ * score[idx]));
        if (prob < kEpsilon) { prob = kEpsilon; }
        // add weighted loss
        sum_loss += -std::log(prob) * weights_[i];
      }
    }
    double loss = sum_loss / sum_weights_;
    return std::vector<double>(1, loss);
  }
private:
  /*! \brief Number of data */
  data_size_t num_data_ = 0;
  /*! \brief Number of classes */
  int num_class_ = 0;
  /*! \brief Pointer of label */
  const float* label_ = nullptr;
  /*! \brief Pointer of weights */
  const float* weights_ = nullptr;
  /*! \brief Sum weights */
  double sum_weights_ = 0.0;
  /*! \brief Name of this test set */
  std::vector<std::string> name_;
  /*! \brief Sigmoid scaling parameter for converting raw scores to probabilities */
  double sigmoid_ = 1.0;
};
} // namespace LightGBM } // namespace LightGBM
#endif // LightGBM_METRIC_MULTICLASS_METRIC_HPP_ #endif // LightGBM_METRIC_MULTICLASS_METRIC_HPP_
...@@ -12,15 +12,21 @@ namespace LightGBM { ...@@ -12,15 +12,21 @@ namespace LightGBM {
*/ */
class BinaryLogloss: public ObjectiveFunction { class BinaryLogloss: public ObjectiveFunction {
public: public:
explicit BinaryLogloss(const ObjectiveConfig& config) { explicit BinaryLogloss(const ObjectiveConfig& config, std::function<bool(float)> is_pos = nullptr) {
is_unbalance_ = config.is_unbalance; is_unbalance_ = config.is_unbalance;
sigmoid_ = static_cast<double>(config.sigmoid); sigmoid_ = static_cast<double>(config.sigmoid);
if (sigmoid_ <= 0.0) { if (sigmoid_ <= 0.0) {
Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_); Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_);
} }
scale_pos_weight_ = static_cast<double>(config.scale_pos_weight); scale_pos_weight_ = static_cast<double>(config.scale_pos_weight);
is_pos_ = is_pos;
if (is_pos_ == nullptr) {
is_pos_ = [](float label) {return label > 0; };
}
} }
~BinaryLogloss() {} ~BinaryLogloss() {}
void Init(const Metadata& metadata, data_size_t num_data) override { void Init(const Metadata& metadata, data_size_t num_data) override {
num_data_ = num_data; num_data_ = num_data;
label_ = metadata.label(); label_ = metadata.label();
...@@ -30,7 +36,7 @@ public: ...@@ -30,7 +36,7 @@ public:
// count for positive and negative samples // count for positive and negative samples
#pragma omp parallel for schedule(static) reduction(+:cnt_positive, cnt_negative) #pragma omp parallel for schedule(static) reduction(+:cnt_positive, cnt_negative)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
if (label_[i] > 0) { if (is_pos_(label_[i])) {
++cnt_positive; ++cnt_positive;
} else { } else {
++cnt_negative; ++cnt_negative;
...@@ -61,7 +67,7 @@ public: ...@@ -61,7 +67,7 @@ public:
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
// get label and label weights // get label and label weights
const int is_pos = label_[i] > 0; const int is_pos = is_pos_(label_[i]);
const int label = label_val_[is_pos]; const int label = label_val_[is_pos];
const double label_weight = label_weights_[is_pos]; const double label_weight = label_weights_[is_pos];
// calculate gradients and hessians // calculate gradients and hessians
...@@ -74,7 +80,7 @@ public: ...@@ -74,7 +80,7 @@ public:
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
// get label and label weights // get label and label weights
const int is_pos = label_[i] > 0; const int is_pos = is_pos_(label_[i]);
const int label = label_val_[is_pos]; const int label = label_val_[is_pos];
const double label_weight = label_weights_[is_pos]; const double label_weight = label_weights_[is_pos];
// calculate gradients and hessians // calculate gradients and hessians
...@@ -106,6 +112,7 @@ private: ...@@ -106,6 +112,7 @@ private:
/*! \brief Weights for data */ /*! \brief Weights for data */
const float* weights_; const float* weights_;
double scale_pos_weight_; double scale_pos_weight_;
std::function<bool(float)> is_pos_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -5,19 +5,22 @@ ...@@ -5,19 +5,22 @@
#include <cstring> #include <cstring>
#include <cmath> #include <cmath>
#include <vector>
#include "binary_objective.hpp"
namespace LightGBM { namespace LightGBM {
/*! /*!
* \brief Objective function for multiclass classification * \brief Objective function for multiclass classification, use softmax as objective functions
*/ */
class MulticlassLogloss: public ObjectiveFunction { class MulticlassSoftmax: public ObjectiveFunction {
public: public:
explicit MulticlassLogloss(const ObjectiveConfig& config) { explicit MulticlassSoftmax(const ObjectiveConfig& config) {
num_class_ = config.num_class; num_class_ = config.num_class;
is_unbalance_ = config.is_unbalance;
} }
~MulticlassLogloss() { ~MulticlassSoftmax() {
} }
void Init(const Metadata& metadata, data_size_t num_data) override { void Init(const Metadata& metadata, data_size_t num_data) override {
...@@ -32,18 +35,6 @@ public: ...@@ -32,18 +35,6 @@ public:
Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]); Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
} }
} }
label_pos_weights_ = std::vector<float>(num_class_, 1);
if (is_unbalance_) {
std::vector<int> cnts(num_class_, 0);
for (int i = 0; i < num_data_; ++i) {
++cnts[label_int_[i]];
}
for (int i = 0; i < num_class_; ++i) {
int cnt_cur = cnts[i];
int cnt_other = (num_data_ - cnts[i]);
label_pos_weights_[i] = static_cast<float>(cnt_other) / cnt_cur;
}
}
} }
void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override { void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
...@@ -52,7 +43,7 @@ public: ...@@ -52,7 +43,7 @@ public:
#pragma omp parallel for schedule(static) private(rec) #pragma omp parallel for schedule(static) private(rec)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
rec.resize(num_class_); rec.resize(num_class_);
for (int k = 0; k < num_class_; ++k){ for (int k = 0; k < num_class_; ++k) {
size_t idx = static_cast<size_t>(num_data_) * k + i; size_t idx = static_cast<size_t>(num_data_) * k + i;
rec[k] = static_cast<double>(score[idx]); rec[k] = static_cast<double>(score[idx]);
} }
...@@ -61,12 +52,11 @@ public: ...@@ -61,12 +52,11 @@ public:
auto p = rec[k]; auto p = rec[k];
size_t idx = static_cast<size_t>(num_data_) * k + i; size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) { if (label_int_[i] == k) {
gradients[idx] = static_cast<score_t>(p - 1.0f) * label_pos_weights_[k]; gradients[idx] = static_cast<score_t>(p - 1.0f);
hessians[idx] = static_cast<score_t>(p * (1.0f - p))* label_pos_weights_[k];
} else { } else {
gradients[idx] = static_cast<score_t>(p); gradients[idx] = static_cast<score_t>(p);
hessians[idx] = static_cast<score_t>(p * (1.0f - p));
} }
hessians[idx] = static_cast<score_t>(p * (1.0f - p));
} }
} }
} else { } else {
...@@ -74,7 +64,7 @@ public: ...@@ -74,7 +64,7 @@ public:
#pragma omp parallel for schedule(static) private(rec) #pragma omp parallel for schedule(static) private(rec)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
rec.resize(num_class_); rec.resize(num_class_);
for (int k = 0; k < num_class_; ++k){ for (int k = 0; k < num_class_; ++k) {
size_t idx = static_cast<size_t>(num_data_) * k + i; size_t idx = static_cast<size_t>(num_data_) * k + i;
rec[k] = static_cast<double>(score[idx]); rec[k] = static_cast<double>(score[idx]);
} }
...@@ -83,13 +73,11 @@ public: ...@@ -83,13 +73,11 @@ public:
auto p = rec[k]; auto p = rec[k];
size_t idx = static_cast<size_t>(num_data_) * k + i; size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) { if (label_int_[i] == k) {
gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]) * label_pos_weights_[k]; gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]);
hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]) * label_pos_weights_[k];
} else { } else {
gradients[idx] = static_cast<score_t>(p * weights_[i]); gradients[idx] = static_cast<score_t>(p * weights_[i]);
hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]);
} }
hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]);
} }
} }
} }
...@@ -110,9 +98,49 @@ private: ...@@ -110,9 +98,49 @@ private:
std::vector<int> label_int_; std::vector<int> label_int_;
/*! \brief Weights for data */ /*! \brief Weights for data */
const float* weights_; const float* weights_;
/*! \brief Weights for label */ };
std::vector<float> label_pos_weights_;
bool is_unbalance_; /*!
* \brief Objective function for multiclass classification, use one-vs-all binary objective function
*/
class MulticlassOVA: public ObjectiveFunction {
public:
explicit MulticlassOVA(const ObjectiveConfig& config) {
num_class_ = config.num_class;
for (int i = 0; i < num_class_; ++i) {
binary_loss_.emplace_back(
new BinaryLogloss(config, [i](float label) { return static_cast<int>(label) == i; }));
}
}
~MulticlassOVA() {
}
void Init(const Metadata& metadata, data_size_t num_data) override {
num_data_ = num_data;
for (int i = 0; i < num_class_; ++i) {
binary_loss_[i]->Init(metadata, num_data);
}
}
void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
for (int i = 0; i < num_class_; ++i) {
int64_t bias = static_cast<int64_t>(num_data_) * i;
binary_loss_[i]->GetGradients(score + bias, gradients + bias, hessians + bias);
}
}
const char* GetName() const override {
return "multiclassova";
}
private:
/*! \brief Number of data */
data_size_t num_data_;
/*! \brief Number of classes */
int num_class_;
std::vector<std::unique_ptr<BinaryLogloss>> binary_loss_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -23,7 +23,9 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& ...@@ -23,7 +23,9 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
} else if (type == std::string("lambdarank")) { } else if (type == std::string("lambdarank")) {
return new LambdarankNDCG(config); return new LambdarankNDCG(config);
} else if (type == std::string("multiclass")) { } else if (type == std::string("multiclass")) {
return new MulticlassLogloss(config); return new MulticlassSoftmax(config);
} else if (type == std::string("multiclassova")) {
return new MulticlassOVA(config);
} }
return nullptr; return nullptr;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment