Commit 841a8987 authored by Guolin Ke's avatar Guolin Ke
Browse files

Support one-vs-all (OVA) multi-class classification.

parent 14195876
......@@ -138,7 +138,8 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin
void OverallConfig::CheckParamConflict() {
// check if objective_type, metric_type, and num_class match
bool objective_type_multiclass = (objective_type == std::string("multiclass"));
bool objective_type_multiclass = (objective_type == std::string("multiclass")
|| objective_type == std::string("multiclassova"));
int num_class_check = boosting_config.num_class;
if (objective_type_multiclass) {
if (num_class_check <= 1) {
......@@ -151,11 +152,19 @@ void OverallConfig::CheckParamConflict() {
}
if (boosting_config.is_provide_training_metric || !io_config.valid_data_filenames.empty()) {
for (std::string metric_type : metric_types) {
bool metric_type_multiclass = (metric_type == std::string("multi_logloss") || metric_type == std::string("multi_error"));
bool metric_type_multiclass = (metric_type == std::string("multi_logloss")
|| metric_type == std::string("multi_error")
|| metric_type == std::string("multi_loglossova"));
if ((objective_type_multiclass && !metric_type_multiclass)
|| (!objective_type_multiclass && metric_type_multiclass)) {
Log::Fatal("Objective and metrics don't match");
}
if (objective_type == std::string("multiclassova") && metric_type == std::string("multi_logloss")) {
Log::Fatal("Wrong metric. For Multi-class with OVA, you should use multi_loglossova metric.");
}
if (objective_type == std::string("multiclass") && metric_type == std::string("multi_loglossova")) {
Log::Fatal("Wrong metric. For Multi-class with softmax, you should use multi_logloss metric.");
}
}
}
......
......@@ -29,7 +29,9 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config
} else if (type == std::string("map")) {
return new MapMetric(config);
} else if (type == std::string("multi_logloss")) {
return new MultiLoglossMetric(config);
return new MultiSoftmaxLoglossMetric(config);
} else if (type == std::string("multi_loglossova")) {
return new MultiOVALoglossMetric(config);
} else if (type == std::string("multi_error")) {
return new MultiErrorMetric(config);
}
......
......@@ -79,8 +79,6 @@ public:
}
private:
/*! \brief Output frequency */
int output_freq_;
/*! \brief Number of data */
data_size_t num_data_;
/*! \brief Number of classes */
......@@ -116,9 +114,9 @@ public:
};
/*! \brief Logloss for multiclass task */
class MultiLoglossMetric: public MulticlassMetric<MultiLoglossMetric> {
class MultiSoftmaxLoglossMetric: public MulticlassMetric<MultiSoftmaxLoglossMetric> {
public:
explicit MultiLoglossMetric(const MetricConfig& config) :MulticlassMetric<MultiLoglossMetric>(config) {}
explicit MultiSoftmaxLoglossMetric(const MetricConfig& config) :MulticlassMetric<MultiSoftmaxLoglossMetric>(config) {}
inline static double LossOnPoint(float label, std::vector<double>& score) {
size_t k = static_cast<size_t>(label);
......@@ -135,5 +133,84 @@ public:
}
};
/*!
* \brief Logloss metric for the one-vs-all multiclass objective ("multiclassova").
*        Scores are laid out class-major: score[num_data * class + row]. For each row
*        only the binary classifier of the true class contributes to the loss:
*        loss_i = -log(sigmoid(sigmoid_ * score[true_class, i])).
*/
class MultiOVALoglossMetric: public Metric {
public:
  explicit MultiOVALoglossMetric(const MetricConfig& config) {
    num_class_ = config.num_class;
    sigmoid_ = config.sigmoid;
  }
  virtual ~MultiOVALoglossMetric() {
  }
  void Init(const Metadata& metadata, data_size_t num_data) override {
    name_.emplace_back("multi_loglossova");
    num_data_ = num_data;
    // get label
    label_ = metadata.label();
    // get weights; a null pointer means every row has unit weight
    weights_ = metadata.weights();
    if (weights_ == nullptr) {
      sum_weights_ = static_cast<double>(num_data_);
    } else {
      sum_weights_ = 0.0;
      for (data_size_t i = 0; i < num_data_; ++i) {
        sum_weights_ += weights_[i];
      }
    }
  }
  const std::vector<std::string>& GetName() const override {
    return name_;
  }
  double factor_to_bigger_better() const override {
    // logloss: smaller is better
    return -1.0f;
  }
  /*!
  * \brief Compute the weighted average OVA logloss over all data points.
  * \param score Raw scores, class-major layout of size num_class * num_data
  * \return One-element vector holding the average loss
  */
  std::vector<double> Eval(const double* score) const override {
    double sum_loss = 0.0;
    if (weights_ == nullptr) {
      #pragma omp parallel for schedule(static) reduction(+:sum_loss)
      for (data_size_t i = 0; i < num_data_; ++i) {
        sum_loss += PointLoss(score, i);
      }
    } else {
      #pragma omp parallel for schedule(static) reduction(+:sum_loss)
      for (data_size_t i = 0; i < num_data_; ++i) {
        sum_loss += PointLoss(score, i) * weights_[i];
      }
    }
    double loss = sum_loss / sum_weights_;
    return std::vector<double>(1, loss);
  }
private:
  /*! \brief Negative log of the true-class sigmoid probability for row i */
  inline double PointLoss(const double* score, data_size_t i) const {
    // index of row i in the column of its true class (class-major layout)
    size_t idx = static_cast<size_t>(num_data_) * static_cast<int>(label_[i]) + i;
    double prob = 1.0f / (1.0f + std::exp(-sigmoid_ * score[idx]));
    // clip to avoid log(0)
    if (prob < kEpsilon) { prob = kEpsilon; }
    return -std::log(prob);
  }
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief Pointer of label */
  const float* label_;
  /*! \brief Pointer of weighs */
  const float* weights_;
  /*! \brief Sum weights */
  double sum_weights_;
  /*! \brief Name of this test set */
  std::vector<std::string> name_;
  /*! \brief Sigmoid scaling parameter applied to raw scores */
  double sigmoid_;
};
} // namespace LightGBM
#endif // LightGBM_METRIC_MULTICLASS_METRIC_HPP_
......@@ -12,15 +12,21 @@ namespace LightGBM {
*/
class BinaryLogloss: public ObjectiveFunction {
public:
explicit BinaryLogloss(const ObjectiveConfig& config) {
explicit BinaryLogloss(const ObjectiveConfig& config, std::function<bool(float)> is_pos = nullptr) {
is_unbalance_ = config.is_unbalance;
sigmoid_ = static_cast<double>(config.sigmoid);
if (sigmoid_ <= 0.0) {
Log::Fatal("Sigmoid parameter %f should be greater than zero", sigmoid_);
}
scale_pos_weight_ = static_cast<double>(config.scale_pos_weight);
is_pos_ = is_pos;
if (is_pos_ == nullptr) {
is_pos_ = [](float label) {return label > 0; };
}
}
~BinaryLogloss() {}
void Init(const Metadata& metadata, data_size_t num_data) override {
num_data_ = num_data;
label_ = metadata.label();
......@@ -30,7 +36,7 @@ public:
// count for positive and negative samples
#pragma omp parallel for schedule(static) reduction(+:cnt_positive, cnt_negative)
for (data_size_t i = 0; i < num_data_; ++i) {
if (label_[i] > 0) {
if (is_pos_(label_[i])) {
++cnt_positive;
} else {
++cnt_negative;
......@@ -61,7 +67,7 @@ public:
#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) {
// get label and label weights
const int is_pos = label_[i] > 0;
const int is_pos = is_pos_(label_[i]);
const int label = label_val_[is_pos];
const double label_weight = label_weights_[is_pos];
// calculate gradients and hessians
......@@ -74,7 +80,7 @@ public:
#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) {
// get label and label weights
const int is_pos = label_[i] > 0;
const int is_pos = is_pos_(label_[i]);
const int label = label_val_[is_pos];
const double label_weight = label_weights_[is_pos];
// calculate gradients and hessians
......@@ -106,6 +112,7 @@ private:
/*! \brief Weights for data */
const float* weights_;
double scale_pos_weight_;
std::function<bool(float)> is_pos_;
};
} // namespace LightGBM
......
......@@ -5,19 +5,22 @@
#include <cstring>
#include <cmath>
#include <vector>
#include "binary_objective.hpp"
namespace LightGBM {
/*!
* \brief Objective function for multiclass classification
* \brief Objective function for multiclass classification, use softmax as objective functions
*/
class MulticlassLogloss: public ObjectiveFunction {
class MulticlassSoftmax: public ObjectiveFunction {
public:
explicit MulticlassLogloss(const ObjectiveConfig& config) {
explicit MulticlassSoftmax(const ObjectiveConfig& config) {
num_class_ = config.num_class;
is_unbalance_ = config.is_unbalance;
}
~MulticlassLogloss() {
~MulticlassSoftmax() {
}
void Init(const Metadata& metadata, data_size_t num_data) override {
......@@ -32,18 +35,6 @@ public:
Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
}
}
label_pos_weights_ = std::vector<float>(num_class_, 1);
if (is_unbalance_) {
std::vector<int> cnts(num_class_, 0);
for (int i = 0; i < num_data_; ++i) {
++cnts[label_int_[i]];
}
for (int i = 0; i < num_class_; ++i) {
int cnt_cur = cnts[i];
int cnt_other = (num_data_ - cnts[i]);
label_pos_weights_[i] = static_cast<float>(cnt_other) / cnt_cur;
}
}
}
void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
......@@ -52,7 +43,7 @@ public:
#pragma omp parallel for schedule(static) private(rec)
for (data_size_t i = 0; i < num_data_; ++i) {
rec.resize(num_class_);
for (int k = 0; k < num_class_; ++k){
for (int k = 0; k < num_class_; ++k) {
size_t idx = static_cast<size_t>(num_data_) * k + i;
rec[k] = static_cast<double>(score[idx]);
}
......@@ -61,12 +52,11 @@ public:
auto p = rec[k];
size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) {
gradients[idx] = static_cast<score_t>(p - 1.0f) * label_pos_weights_[k];
hessians[idx] = static_cast<score_t>(p * (1.0f - p))* label_pos_weights_[k];
gradients[idx] = static_cast<score_t>(p - 1.0f);
} else {
gradients[idx] = static_cast<score_t>(p);
hessians[idx] = static_cast<score_t>(p * (1.0f - p));
}
hessians[idx] = static_cast<score_t>(p * (1.0f - p));
}
}
} else {
......@@ -74,7 +64,7 @@ public:
#pragma omp parallel for schedule(static) private(rec)
for (data_size_t i = 0; i < num_data_; ++i) {
rec.resize(num_class_);
for (int k = 0; k < num_class_; ++k){
for (int k = 0; k < num_class_; ++k) {
size_t idx = static_cast<size_t>(num_data_) * k + i;
rec[k] = static_cast<double>(score[idx]);
}
......@@ -83,13 +73,11 @@ public:
auto p = rec[k];
size_t idx = static_cast<size_t>(num_data_) * k + i;
if (label_int_[i] == k) {
gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]) * label_pos_weights_[k];
hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]) * label_pos_weights_[k];
gradients[idx] = static_cast<score_t>((p - 1.0f) * weights_[i]);
} else {
gradients[idx] = static_cast<score_t>(p * weights_[i]);
hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]);
}
hessians[idx] = static_cast<score_t>(p * (1.0f - p) * weights_[i]);
}
}
}
......@@ -110,9 +98,49 @@ private:
std::vector<int> label_int_;
/*! \brief Weights for data */
const float* weights_;
/*! \brief Weights for label */
std::vector<float> label_pos_weights_;
bool is_unbalance_;
};
/*!
* \brief Objective function for multiclass classification, use one-vs-all binary objective function
*/
class MulticlassOVA: public ObjectiveFunction {
public:
explicit MulticlassOVA(const ObjectiveConfig& config) {
num_class_ = config.num_class;
for (int i = 0; i < num_class_; ++i) {
binary_loss_.emplace_back(
new BinaryLogloss(config, [i](float label) { return static_cast<int>(label) == i; }));
}
}
~MulticlassOVA() {
}
void Init(const Metadata& metadata, data_size_t num_data) override {
num_data_ = num_data;
for (int i = 0; i < num_class_; ++i) {
binary_loss_[i]->Init(metadata, num_data);
}
}
void GetGradients(const double* score, score_t* gradients, score_t* hessians) const override {
for (int i = 0; i < num_class_; ++i) {
int64_t bias = static_cast<int64_t>(num_data_) * i;
binary_loss_[i]->GetGradients(score + bias, gradients + bias, hessians + bias);
}
}
const char* GetName() const override {
return "multiclassova";
}
private:
/*! \brief Number of data */
data_size_t num_data_;
/*! \brief Number of classes */
int num_class_;
std::vector<std::unique_ptr<BinaryLogloss>> binary_loss_;
};
} // namespace LightGBM
......
......@@ -23,7 +23,9 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
} else if (type == std::string("lambdarank")) {
return new LambdarankNDCG(config);
} else if (type == std::string("multiclass")) {
return new MulticlassLogloss(config);
return new MulticlassSoftmax(config);
} else if (type == std::string("multiclassova")) {
return new MulticlassOVA(config);
}
return nullptr;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment