Commit aa796a85 authored by wxchan, committed by Guolin Ke

Support multiclass classification (#53)

parent 90ffe1c9
task = predict
data = multiclass.test
input_model = LightGBM_model.txt
# task type, supports train and predict
task = train
# boosting type, supports gbdt for now, alias: boosting, boost
boosting_type = gbdt
# application type, supports the following applications
# regression, regression task
# binary, binary classification task
# lambdarank, lambdarank task
# multiclass, multiclass classification task
# alias: application, app
objective = multiclass
# eval metrics, supports multiple metrics delimited by ',', supports the following metrics
# l1
# l2, default metric for regression
# ndcg, default metric for lambdarank
# auc
# binary_logloss, default metric for binary
# binary_error
# multi_logloss
# multi_error
metric = multi_logloss
# number of classes, for multiclass classification
num_class = 5
# frequency for metric output
metric_freq = 1
# true if metrics should also be output for the training data, alias: training_metric, train_metric
is_training_metric = true
# number of bins for feature buckets; 255 is a recommended setting that saves memory while keeping good accuracy
max_bin = 255
# training data
# if a weight file exists, it should be named "multiclass.train.weight"
# alias: train_data, train
data = multiclass.train
# valid data
valid_data = multiclass.test
# rounds for early stopping
early_stopping = 10
# number of trees (iterations), alias: num_tree, num_iteration, num_iterations, num_round, num_rounds
num_trees = 100
# shrinkage rate, alias: shrinkage_rate
learning_rate = 0.05
# number of leaves for one tree, alias: num_leaf
num_leaves = 31
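
For reference: with the settings above (num_class = 5), the predict task writes one line per record containing five tab-separated class probabilities that sum to 1 (the predictor in this commit joins per-class outputs with '\t'). An illustrative row (made-up values):

0.61	0.12	0.09	0.10	0.08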
......@@ -80,6 +80,13 @@ public:
const float* feature_values,
int num_used_model) const = 0;
/*!
* \brief Prediction for multiclass classification
* \param feature_values Feature value on this record
* \param num_used_model Number of models used for prediction
* \return Prediction result, num_class values per record
*/
virtual std::vector<float> PredictMulticlass(const float* value, int num_used_model) const = 0;
/*!
* \brief save model to file
*/
......@@ -108,7 +115,13 @@ public:
* \return Number of weak sub-models
*/
virtual int NumberOfSubModels() const = 0;
/*!
* \brief Get number of classes
* \return Number of classes
*/
virtual int NumberOfClass() const = 0;
/*!
* \brief Get Type name of this boosting object
*/
......
......@@ -128,6 +128,8 @@ public:
int max_position = 20;
// for binary
bool is_unbalance = false;
// for multiclass
int num_class = 1;
void Set(const std::unordered_map<std::string, std::string>& params) override;
};
......@@ -135,6 +137,7 @@ public:
struct MetricConfig: public ConfigBase {
public:
virtual ~MetricConfig() {}
int num_class = 1;
float sigmoid = 1.0f;
std::vector<float> label_gain;
std::vector<int> eval_at;
......@@ -179,6 +182,7 @@ public:
int bagging_seed = 3;
int bagging_freq = 0;
int early_stopping_round = 0;
int num_class = 1;
void Set(const std::unordered_map<std::string, std::string>& params) override;
};
......
......@@ -336,6 +336,26 @@ static inline int64_t Pow2RoundUp(int64_t x) {
return 0;
}
/*!
* \brief Do inplace softmax transformation on p_rec
* \param p_rec The input/output vector of the values.
*/
inline void Softmax(std::vector<float>* p_rec) {
std::vector<float> &rec = *p_rec;
float wmax = rec[0];
for (size_t i = 1; i < rec.size(); ++i) {
wmax = std::max(rec[i], wmax);
}
float wsum = 0.0f;
for (size_t i = 0; i < rec.size(); ++i) {
rec[i] = std::exp(rec[i] - wmax);
wsum += rec[i];
}
for (size_t i = 0; i < rec.size(); ++i) {
rec[i] /= static_cast<float>(wsum);
}
}
} // namespace Common
} // namespace LightGBM
......
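
A minimal usage sketch of the helper above (input values invented): subtracting the row maximum before exponentiating keeps std::exp from overflowing on large scores without changing the result.

std::vector<float> rec = {2.0f, 1.0f, 0.1f};
LightGBM::Common::Softmax(&rec);
// rec is now roughly {0.659f, 0.242f, 0.099f} and sums to 1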
......@@ -33,6 +33,7 @@ public:
num_used_model_(num_used_model) {
boosting_ = boosting;
num_features_ = boosting_->MaxFeatureIdx() + 1;
num_class_ = boosting_->NumberOfClass();
#pragma omp parallel
#pragma omp master
{
......@@ -87,6 +88,18 @@ public:
// get result with sigmoid transform if needed
return boosting_->Predict(features_[tid], num_used_model_);
}
/*!
* \brief prediction for multiclass classification
* \param features Feature of this record
* \return Prediction result
*/
std::vector<float> PredictMulticlassOneLine(const std::vector<std::pair<int, float>>& features) {
const int tid = PutFeatureValuesToBuffer(features);
// get result with softmax transform
return boosting_->PredictMulticlass(features_[tid], num_used_model_);
}
/*!
* \brief predicting on data, then saving result to disk
* \param data_filename Filename of data
......@@ -120,17 +133,30 @@ public:
};
std::function<std::string(const std::vector<std::pair<int, float>>&)> predict_fun;
if (is_predict_leaf_index_) {
if (num_class_ > 1) {
predict_fun = [this](const std::vector<std::pair<int, float>>& features){
std::vector<float> prediction = PredictMulticlassOneLine(features);
std::stringstream result_stream_buf;
for (size_t i = 0; i < prediction.size(); ++i){
if (i > 0) {
result_stream_buf << '\t';
}
result_stream_buf << prediction[i];
}
return result_stream_buf.str();
};
}
else if (is_predict_leaf_index_) {
predict_fun = [this](const std::vector<std::pair<int, float>>& features){
std::vector<int> predicted_leaf_index = PredictLeafIndexOneLine(features);
std::stringstream result_ss;
std::stringstream result_stream_buf;
for (size_t i = 0; i < predicted_leaf_index.size(); ++i){
if (i > 0) {
result_ss << '\t';
result_stream_buf << '\t';
}
result_ss << predicted_leaf_index[i];
result_stream_buf << predicted_leaf_index[i];
}
return result_ss.str();
return result_stream_buf.str();
};
}
else {
......@@ -189,6 +215,8 @@ private:
float** features_;
/*! \brief Number of features */
int num_features_;
/*! \brief Number of classes */
int num_class_;
/*! \brief True if need to predict result with sigmoid transform */
bool is_simgoid_;
/*! \brief Number of threads */
......
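
In the multiclass branch above, each output line is the per-class probability vector joined by tabs; sketched with invented probabilities:

// PredictMulticlassOneLine(features) returns e.g. {0.2f, 0.7f, 0.1f},
// which the lambda serializes as the output line "0.2\t0.7\t0.1".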
......@@ -17,13 +17,15 @@
namespace LightGBM {
GBDT::GBDT()
: tree_learner_(nullptr), train_score_updater_(nullptr),
: train_score_updater_(nullptr),
gradients_(nullptr), hessians_(nullptr),
out_of_bag_data_indices_(nullptr), bag_data_indices_(nullptr) {
}
GBDT::~GBDT() {
if (tree_learner_ != nullptr) { delete tree_learner_; }
for (auto& tree_learner: tree_learner_){
if (tree_learner != nullptr) { delete tree_learner; }
}
if (gradients_ != nullptr) { delete[] gradients_; }
if (hessians_ != nullptr) { delete[] hessians_; }
if (out_of_bag_data_indices_ != nullptr) { delete[] out_of_bag_data_indices_; }
......@@ -44,23 +46,27 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
max_feature_idx_ = 0;
early_stopping_round_ = gbdt_config_->early_stopping_round;
train_data_ = train_data;
num_class_ = config->num_class;
tree_learner_ = std::vector<TreeLearner*>(num_class_, nullptr);
// create tree learner
tree_learner_ =
TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config);
// init tree learner
tree_learner_->Init(train_data_);
for (int i = 0; i < num_class_; ++i){
tree_learner_[i] =
TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config);
// init tree learner
tree_learner_[i]->Init(train_data_);
}
object_function_ = object_function;
// push training metrics
for (const auto& metric : training_metrics) {
training_metrics_.push_back(metric);
}
// create score tracker
train_score_updater_ = new ScoreUpdater(train_data_);
train_score_updater_ = new ScoreUpdater(train_data_, num_class_);
num_data_ = train_data_->num_data();
// create buffer for gradients and hessians
if (object_function_ != nullptr) {
gradients_ = new score_t[num_data_];
hessians_ = new score_t[num_data_];
gradients_ = new score_t[num_data_ * num_class_];
hessians_ = new score_t[num_data_ * num_class_];
}
// get max feature index
......@@ -85,7 +91,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
void GBDT::AddDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) {
// for a validation dataset, we need its score and metric
valid_score_updater_.push_back(new ScoreUpdater(valid_data));
valid_score_updater_.push_back(new ScoreUpdater(valid_data, num_class_));
valid_metrics_.emplace_back();
best_iter_.emplace_back();
best_score_.emplace_back();
......@@ -97,7 +103,7 @@ void GBDT::AddDataset(const Dataset* valid_data,
}
void GBDT::Bagging(int iter) {
void GBDT::Bagging(int iter, const int curr_class) {
// if need bagging
if (out_of_bag_data_indices_ != nullptr && iter % gbdt_config_->bagging_freq == 0) {
// if doesn't have query data
......@@ -146,52 +152,59 @@ void GBDT::Bagging(int iter) {
}
Log::Info("re-bagging, using %d data to train", bag_data_cnt_);
// set bagging data to tree learner
tree_learner_->SetBaggingData(bag_data_indices_, bag_data_cnt_);
tree_learner_[curr_class]->SetBaggingData(bag_data_indices_, bag_data_cnt_);
}
}
void GBDT::UpdateScoreOutOfBag(const Tree* tree) {
void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
// we need to predict out-of-bag scores of data for boosting
if (out_of_bag_data_indices_ != nullptr) {
train_score_updater_->
AddScore(tree, out_of_bag_data_indices_, out_of_bag_data_cnt_);
AddScore(tree, out_of_bag_data_indices_, out_of_bag_data_cnt_, curr_class);
}
}
bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
// boosting first
if (gradient == nullptr || hessian == nullptr) {
Boosting();
gradient = gradients_;
hessian = hessians_;
}
// bagging logic
Bagging(iter_);
// train a new tree
Tree * new_tree = tree_learner_->Train(gradient, hessian);
// if cannot learn a new tree, then stop
if (new_tree->num_leaves() <= 1) {
Log::Info("Can't training anymore, there isn't any leaf meets split requirements.");
return true;
}
// shrinkage by learning rate
new_tree->Shrinkage(gbdt_config_->learning_rate);
// update score
UpdateScore(new_tree);
UpdateScoreOutOfBag(new_tree);
// boosting first
if (gradient == nullptr || hessian == nullptr) {
Boosting();
gradient = gradients_;
hessian = hessians_;
}
for (int curr_class = 0; curr_class < num_class_; ++curr_class){
// bagging logic
Bagging(iter_, curr_class);
// train a new tree
Tree* new_tree = tree_learner_[curr_class]->Train(gradient + curr_class * num_data_, hessian + curr_class * num_data_);
// if cannot learn a new tree, then stop
if (new_tree->num_leaves() <= 1) {
Log::Info("Can't training anymore, there isn't any leaf meets split requirements.");
return true;
}
// shrinkage by learning rate
new_tree->Shrinkage(gbdt_config_->learning_rate);
// update score
UpdateScore(new_tree, curr_class);
UpdateScoreOutOfBag(new_tree, curr_class);
// add model
models_.push_back(new_tree);
}
bool is_met_early_stopping = false;
// print message for metric
if (is_eval) {
is_met_early_stopping = OutputMetric(iter_ + 1);
}
// add model
models_.push_back(new_tree);
++iter_;
if (is_met_early_stopping) {
Log::Info("Early stopping at iteration %d, the best iteration round is %d",
iter_, iter_ - early_stopping_round_);
// pop last early_stopping_round_ models
for (int i = 0; i < early_stopping_round_; ++i) {
for (int i = 0; i < early_stopping_round_ * num_class_; ++i) {
delete models_.back();
models_.pop_back();
}
......@@ -200,12 +213,12 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
}
void GBDT::UpdateScore(const Tree* tree) {
void GBDT::UpdateScore(const Tree* tree, const int curr_class) {
// update training score
train_score_updater_->AddScore(tree_learner_);
train_score_updater_->AddScore(tree_learner_[curr_class], curr_class);
// update validation score
for (auto& score_tracker : valid_score_updater_) {
score_tracker->AddScore(tree);
score_tracker->AddScore(tree, curr_class);
}
}
......@@ -298,6 +311,8 @@ void GBDT::SaveModelToFile(bool is_finish, const char* filename) {
model_output_file_.open(filename);
// output model type
model_output_file_ << "gbdt" << std::endl;
// output number of class
model_output_file_ << "num_class=" << num_class_ << std::endl;
// output label index
model_output_file_ << "label_index=" << label_idx_ << std::endl;
// output max_feature_idx
......@@ -311,7 +326,7 @@ void GBDT::SaveModelToFile(bool is_finish, const char* filename) {
if (!model_output_file_.is_open()) {
return;
}
int rest = static_cast<int>(models_.size()) - early_stopping_round_;
int rest = static_cast<int>(models_.size()) - early_stopping_round_ * num_class_;
// output tree models
for (int i = saved_model_size_; i < rest; ++i) {
model_output_file_ << "Tree=" << i << std::endl;
......@@ -337,8 +352,26 @@ void GBDT::ModelsFromString(const std::string& model_str) {
models_.clear();
std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n');
size_t i = 0;
// get number of classes
while (i < lines.size()) {
size_t find_pos = lines[i].find("num_class=");
if (find_pos != std::string::npos) {
std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
Common::Atoi(strs[1].c_str(), &num_class_);
++i;
break;
} else {
++i;
}
}
if (i == lines.size()) {
Log::Fatal("Model file doesn't contain number of class");
return;
}
// get index of label
i = 0;
while (i < lines.size()) {
size_t find_pos = lines[i].find("label_index=");
if (find_pos != std::string::npos) {
......@@ -460,6 +493,20 @@ float GBDT::Predict(const float* value, int num_used_model) const {
return ret;
}
std::vector<float> GBDT::PredictMulticlass(const float* value, int num_used_model) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size()) / num_class_;
}
std::vector<float> ret(num_class_, 0.0f);
for (int i = 0; i < num_used_model; ++i) {
for (int j = 0; j < num_class_; ++j){
ret[j] += models_[i * num_class_ + j]->Predict(value);
}
}
Common::Softmax(&ret);
return ret;
}
std::vector<int> GBDT::PredictLeafIndex(const float* value, int num_used_model) const {
if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size());
......
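
For orientation, a sketch (not itself part of the commit) of the two flattened layouts the changes above rely on:

// score_/gradients_/hessians_: class-major, length num_data_ * num_class_;
//   the entry for (data index i, class k) lives at k * num_data_ + i.
// models_: iteration-major, length num_iterations * num_class_;
//   the tree for (iteration t, class k) lives at t * num_class_ + k,
//   hence PredictMulticlass reads models_[i * num_class_ + j].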
......@@ -68,13 +68,20 @@ public:
*/
float Predict(const float* feature_values, int num_used_model) const override;
/*!
* \brief Prediction for multiclass classification
* \param feature_values Feature value on this record
* \param num_used_model Number of models used for prediction
* \return Prediction result, num_class values per record
*/
std::vector<float> PredictMulticlass(const float* value, int num_used_model) const override;
/*!
* \brief Prediction for one record with leaf index
* \param feature_values Feature value on this record
* \param num_used_model Number of used model
* \return Predicted leaf index for this record
*/
std::vector<int> PredictLeafIndex(const float* value, int num_used_model) const override;
/*!
* \brief Serialize models by string
......@@ -103,6 +110,12 @@ public:
*/
inline int NumberOfSubModels() const override { return static_cast<int>(models_.size()); }
/*!
* \brief Get number of classes
* \return Number of classes
*/
inline int NumberOfClass() const override { return num_class_; }
/*!
* \brief Get Type name of this boosting object
*/
......@@ -112,14 +125,16 @@ private:
/*!
* \brief Implement bagging logic
* \param iter Current iteration
* \param curr_class Current class for multiclass training
*/
void Bagging(int iter);
void Bagging(int iter, const int curr_class);
/*!
* \brief updating score for out-of-bag data.
* Scores should be updated since we may re-bag the data during training
* \param tree Trained tree of this iteration
* \param curr_class Current class for multiclass training
*/
void UpdateScoreOutOfBag(const Tree* tree);
void UpdateScoreOutOfBag(const Tree* tree, const int curr_class);
/*!
* \brief calculate the objective function
*/
......@@ -127,8 +142,9 @@ private:
/*!
* \brief updating score after tree was trained
* \param tree Trained tree of this iteration
* \param curr_class Current class for multiclass training
*/
void UpdateScore(const Tree* tree);
void UpdateScore(const Tree* tree, const int curr_class);
/*!
* \brief Print metric result of current iteration
* \param iter Current iteration
......@@ -146,7 +162,7 @@ private:
/*! \brief Config of gbdt */
const GBDTConfig* gbdt_config_;
/*! \brief Tree learner, will use this class to learn trees */
TreeLearner* tree_learner_;
std::vector<TreeLearner*> tree_learner_;
/*! \brief Objective function */
const ObjectiveFunction* object_function_;
/*! \brief Store and update training data's score */
......@@ -180,6 +196,8 @@ private:
data_size_t bag_data_cnt_;
/*! \brief Number of training data */
data_size_t num_data_;
/*! \brief Number of classes */
int num_class_;
/*! \brief Random generator, used for bagging */
Random random_;
/*!
......
......@@ -18,12 +18,12 @@ public:
* \brief Constructor, will pass a const pointer of dataset
* \param data This class will bind with this data set
*/
explicit ScoreUpdater(const Dataset* data)
explicit ScoreUpdater(const Dataset* data, int num_class)
:data_(data) {
num_data_ = data->num_data();
score_ = new score_t[num_data_];
score_ = new score_t[num_data_ * num_class];
// default start score is zero
std::memset(score_, 0, sizeof(score_t)*num_data_);
std::memset(score_, 0, sizeof(score_t) * num_data_ * num_class);
const score_t* init_score = data->metadata().init_score();
// if exists initial score, will start from it
if (init_score != nullptr) {
......@@ -41,8 +41,8 @@ public:
* Note: this function generally will be used on validation data too.
* \param tree Trained tree model
*/
inline void AddScore(const Tree* tree) {
tree->AddPredictionToScore(data_, num_data_, score_);
inline void AddScore(const Tree* tree, int curr_class) {
tree->AddPredictionToScore(data_, num_data_, score_ + curr_class * num_data_);
}
/*!
* \brief Adding prediction score, only used for training data.
......@@ -50,8 +50,8 @@ public:
* Based on this we can get predictions quickly.
* \param tree_learner
*/
inline void AddScore(const TreeLearner* tree_learner) {
tree_learner->AddPredictionToScore(score_);
inline void AddScore(const TreeLearner* tree_learner, int curr_class) {
tree_learner->AddPredictionToScore(score_ + curr_class * num_data_);
}
/*!
* \brief Using tree model to get prediction number, then adding to scores for parts of data
......@@ -61,8 +61,8 @@ public:
* \param data_cnt Number of data that will be processed
*/
inline void AddScore(const Tree* tree, const data_size_t* data_indices,
data_size_t data_cnt) {
tree->AddPredictionToScore(data_, data_indices, data_cnt, score_);
data_size_t data_cnt, int curr_class) {
tree->AddPredictionToScore(data_, data_indices, data_cnt, score_ + curr_class * num_data_);
}
/*! \brief Pointer of score */
inline const score_t * score() { return score_; }
......@@ -72,7 +72,7 @@ private:
data_size_t num_data_;
/*! \brief Pointer of data set */
const Dataset* data_;
/*! \brief scores for data set */
/*! \brief Scores for data set */
score_t* score_;
};
......
......@@ -46,7 +46,6 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
boosting_config = new GBDTConfig();
}
// sub-config setup
network_config.Set(params);
io_config.Set(params);
......@@ -132,7 +131,29 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin
}
void OverallConfig::CheckParamConflict() {
GBDTConfig* gbdt_config = dynamic_cast<GBDTConfig*>(boosting_config);
// check if objective_type, metric_type, and num_class match
bool objective_type_multiclass = (objective_type == std::string("multiclass"));
int num_class_check = gbdt_config->num_class;
if (objective_type_multiclass) {
if (num_class_check <= 1) {
Log::Fatal("You should specify the number of classes (>= 2) for multiclass training.");
}
} else {
if (task_type == TaskType::kTrain && num_class_check != 1) {
Log::Fatal("Number of classes must be 1 for non-multiclass training.");
}
}
for (std::string metric_type : metric_types) {
bool metric_type_multiclass = (metric_type == std::string("multi_logloss") || metric_type == std::string("multi_error"));
if ((objective_type_multiclass && !metric_type_multiclass)
|| (!objective_type_multiclass && metric_type_multiclass)) {
Log::Fatal("Objective and metrics don't match.");
}
}
if (network_config.num_machines > 1) {
is_parallel = true;
} else {
......@@ -196,6 +217,8 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa
GetFloat(params, "sigmoid", &sigmoid);
GetInt(params, "max_position", &max_position);
CHECK(max_position > 0);
GetInt(params, "num_class", &num_class);
CHECK(num_class >= 1);
std::string tmp_str = "";
if (GetString(params, "label_gain", &tmp_str)) {
label_gain = Common::StringToFloatArray(tmp_str, ',');
......@@ -212,6 +235,8 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa
void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetFloat(params, "sigmoid", &sigmoid);
GetInt(params, "num_class", &num_class);
CHECK(num_class >= 1);
std::string tmp_str = "";
if (GetString(params, "label_gain", &tmp_str)) {
label_gain = Common::StringToFloatArray(tmp_str, ',');
......@@ -268,6 +293,8 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
GetInt(params, "metric_freq", &output_freq);
CHECK(output_freq >= 0);
GetBool(params, "is_training_metric", &is_provide_training_metric);
GetInt(params, "num_class", &num_class);
CHECK(num_class >= 1);
}
void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
......
......@@ -2,6 +2,7 @@
#include "regression_metric.hpp"
#include "binary_metric.hpp"
#include "rank_metric.hpp"
#include "multiclass_metric.hpp"
namespace LightGBM {
......@@ -18,6 +19,10 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config
return new AUCMetric(config);
} else if (type == "ndcg") {
return new NDCGMetric(config);
} else if (type == "multi_logloss"){
return new MultiLoglossMetric(config);
} else if (type == "multi_error"){
return new MultiErrorMetric(config);
}
return nullptr;
}
......
#ifndef LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_
#define LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/metric.h>
#include <cmath>
#include <sstream>
#include <string>
#include <vector>
namespace LightGBM {
/*!
* \brief Metric for multiclass task.
* Use static class "PointWiseLossCalculator" to calculate loss point-wise
*/
template<typename PointWiseLossCalculator>
class MulticlassMetric: public Metric {
public:
explicit MulticlassMetric(const MetricConfig& config) {
num_class_ = config.num_class;
}
virtual ~MulticlassMetric() {
}
void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
std::stringstream str_buf;
str_buf << test_name << "'s " << PointWiseLossCalculator::Name();
name_ = str_buf.str();
num_data_ = num_data;
// get label
label_ = metadata.label();
// get weights
weights_ = metadata.weights();
if (weights_ == nullptr) {
sum_weights_ = static_cast<float>(num_data_);
} else {
sum_weights_ = 0.0f;
for (data_size_t i = 0; i < num_data_; ++i) {
sum_weights_ += weights_[i];
}
}
}
const char* GetName() const override {
return name_.c_str();
}
bool is_bigger_better() const override {
return false;
}
std::vector<score_t> Eval(const score_t* score) const override {
score_t sum_loss = 0.0;
if (weights_ == nullptr) {
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
for (data_size_t i = 0; i < num_data_; ++i) {
std::vector<score_t> rec(num_class_);
for (int k = 0; k < num_class_; ++k) {
rec[k] = score[k * num_data_ + i];
}
// add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec);
}
} else {
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
for (data_size_t i = 0; i < num_data_; ++i) {
std::vector<score_t> rec(num_class_);
for (int k = 0; k < num_class_; ++k) {
rec[k] = score[k * num_data_ + i];
}
// add loss
sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec) * weights_[i];
}
}
score_t loss = sum_loss / sum_weights_;
return std::vector<score_t>(1, loss);
}
private:
/*! \brief Output frequency */
int output_freq_;
/*! \brief Number of data */
data_size_t num_data_;
/*! \brief Number of classes */
int num_class_;
/*! \brief Pointer of label */
const float* label_;
/*! \brief Pointer of weights */
const float* weights_;
/*! \brief Sum weights */
float sum_weights_;
/*! \brief Name of this test set */
std::string name_;
};
/*! \brief Error rate for multiclass task */
class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> {
public:
explicit MultiErrorMetric(const MetricConfig& config) :MulticlassMetric<MultiErrorMetric>(config) {}
inline static score_t LossOnPoint(float label, const std::vector<score_t>& score) {
size_t k = static_cast<size_t>(label);
for (size_t i = 0; i < score.size(); ++i) {
// another class outscores the true class: this record is misclassified
if (i != k && score[i] > score[k]) {
return 1.0f;
}
}
return 0.0f;
}
inline static const char* Name() {
return "multi error";
}
};
/*! \brief Logloss for multiclass task */
class MultiLoglossMetric: public MulticlassMetric<MultiLoglossMetric> {
public:
explicit MultiLoglossMetric(const MetricConfig& config) :MulticlassMetric<MultiLoglossMetric>(config) {}
inline static score_t LossOnPoint(float label, std::vector<score_t> score) {
size_t k = static_cast<size_t>(label);
Common::Softmax(&score);
if (score[k] > kEpsilon) {
return -std::log(score[k]);
} else {
return -std::log(kEpsilon);
}
}
inline static const char* Name() {
return "multi logloss";
}
};
} // namespace LightGBM
#endif  // LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_
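
A small worked example under these definitions (raw scores invented): with label = 1 and scores {0.5, 2.0, 1.0}, no other class outscores the true class, so multi_error contributes 0 for this record; multi_logloss first applies Common::Softmax, giving roughly {0.14, 0.63, 0.23}, and contributes -log(0.63) ≈ 0.46.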
......@@ -8,7 +8,7 @@
namespace LightGBM {
/*!
* \brief Objective funtion for binary classification
* \brief Objective function for binary classification
*/
class BinaryLogloss: public ObjectiveFunction {
public:
......
#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#include <LightGBM/objective_function.h>
#include <LightGBM/utils/common.h>
#include <LightGBM/utils/log.h>
#include <cstring>
#include <cmath>
#include <vector>
namespace LightGBM {
/*!
* \brief Objective function for multiclass classification
*/
class MulticlassLogloss: public ObjectiveFunction {
public:
explicit MulticlassLogloss(const ObjectiveConfig& config)
:label_int_(nullptr) {
num_class_ = config.num_class;
}
~MulticlassLogloss() {
if (label_int_ != nullptr) { delete[] label_int_; }
}
void Init(const Metadata& metadata, data_size_t num_data) override {
num_data_ = num_data;
label_ = metadata.label();
weights_ = metadata.weights();
label_int_ = new int[num_data_];
for (data_size_t i = 0; i < num_data_; ++i) {
label_int_[i] = static_cast<int>(label_[i]);
if (label_int_[i] < 0 || label_int_[i] >= num_class_) {
Log::Fatal("Label must be in [0, %d), but found %d in label", num_class_, label_int_[i]);
}
}
}
void GetGradients(const score_t* score, score_t* gradients, score_t* hessians) const override {
if (weights_ == nullptr) {
#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) {
std::vector<score_t> rec(num_class_);
for (int k = 0; k < num_class_; ++k){
rec[k] = score[k * num_data_ + i];
}
Common::Softmax(&rec);
for (int k = 0; k < num_class_; ++k) {
score_t p = rec[k];
if (label_int_[i] == k) {
gradients[k * num_data_ + i] = p - 1.0f;
} else {
gradients[k * num_data_ + i] = p;
}
hessians[k * num_data_ + i] = 2.0f * p * (1.0f - p);
}
}
} else {
#pragma omp parallel for schedule(static)
for (data_size_t i = 0; i < num_data_; ++i) {
std::vector<score_t> rec(num_class_);
for (int k = 0; k < num_class_; ++k){
rec[k] = score[k * num_data_ + i];
}
Common::Softmax(&rec);
for (int k = 0; k < num_class_; ++k) {
float p = rec[k];
if (label_int_[i] == k) {
gradients[k * num_data_ + i] = (p - 1.0f) * weights_[i];
} else {
gradients[k * num_data_ + i] = p * weights_[i];
}
hessians[k * num_data_ + i] = 2.0f * p * (1.0f - p) * weights_[i];
}
}
}
}
float GetSigmoid() const override {
return -1.0f;
}
private:
/*! \brief Number of data */
data_size_t num_data_;
/*! \brief Number of classes */
int num_class_;
/*! \brief Pointer of label */
const float* label_;
/*! \brief Corresponding integers of label_ */
int* label_int_;
/*! \brief Weights for data */
const float* weights_;
};
} // namespace LightGBM
#endif  // LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
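
The gradients above are the standard softmax cross-entropy derivatives; for reference (the factor of 2 on the hessian is this commit's scaling choice, not part of the textbook formula):

L = -\log p_y, \qquad p_k = \frac{e^{s_k}}{\sum_j e^{s_j}}
\frac{\partial L}{\partial s_k} = p_k - \mathbf{1}[k = y]
\frac{\partial^2 L}{\partial s_k^2} = p_k\,(1 - p_k)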
......@@ -2,6 +2,7 @@
#include "regression_objective.hpp"
#include "binary_objective.hpp"
#include "rank_objective.hpp"
#include "multiclass_objective.hpp"
namespace LightGBM {
......@@ -12,6 +13,8 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
return new BinaryLogloss(config);
} else if (type == "lambdarank") {
return new LambdarankNDCG(config);
} else if (type == "multiclass") {
return new MulticlassLogloss(config);
}
return nullptr;
}
......
......@@ -14,7 +14,7 @@
namespace LightGBM {
/*!
* \brief Objective funtion for Lambdrank with NDCG
* \brief Objective function for Lambdrank with NDCG
*/
class LambdarankNDCG: public ObjectiveFunction {
public:
......
......@@ -5,7 +5,7 @@
namespace LightGBM {
/*!
* \brief Objective funtion for regression
* \brief Objective function for regression
*/
class RegressionL2loss: public ObjectiveFunction {
public:
......
......@@ -185,11 +185,13 @@
<ClInclude Include="..\src\metric\binary_metric.hpp" />
<ClInclude Include="..\src\metric\rank_metric.hpp" />
<ClInclude Include="..\src\metric\regression_metric.hpp" />
<ClInclude Include="..\src\metric\multiclass_metric.hpp" />
<ClInclude Include="..\src\network\linkers.h" />
<ClInclude Include="..\src\network\socket_wrapper.hpp" />
<ClInclude Include="..\src\objective\binary_objective.hpp" />
<ClInclude Include="..\src\objective\rank_objective.hpp" />
<ClInclude Include="..\src\objective\regression_objective.hpp" />
<ClInclude Include="..\src\objective\multiclass_objective.hpp" />
<ClInclude Include="..\src\treelearner\data_partition.hpp" />
<ClInclude Include="..\src\treelearner\feature_histogram.hpp" />
<ClInclude Include="..\src\treelearner\leaf_splits.hpp" />
......