"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "b09da434f0af15499dc4ce6ef498e524c81303d2"
Commit aa796a85 authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

Support multiclass classification (#53)

Support multiclass classification (#53)
parent 90ffe1c9
This diff is collapsed.
This diff is collapsed.
task = predict
data = multiclass.test
input_model= LightGBM_model.txt
# task type, support train and predict
task = train
# boosting type, support gbdt for now, alias: boosting, boost
boosting_type = gbdt
# application type, support following application
# regression , regression task
# binary , binary classification task
# lambdarank , lambdarank task
# multiclass
# alias: application, app
objective = multiclass
# eval metrics, support multiple metrics, delimited by ',', support following metrics
# l1
# l2 , default metric for regression
# ndcg , default metric for lambdarank
# auc
# binary_logloss , default metric for binary
# binary_error
# multi_logloss
# multi_error
metric = multi_logloss
# number of class, for multiclass classification
num_class = 5
# frequency of metric output
metric_freq = 1
# true if need output metric for training data, alias: tranining_metric, train_metric
is_training_metric = true
# number of bins for feature bucket; 255 is a recommended setting, it can save memory while keeping good accuracy.
max_bin = 255
# training data
# if a weight file exists, it should be named "multiclass.train.weight"
# alias: train_data, train
data = multiclass.train
# valid data
valid_data = multiclass.test
# round for early stopping
early_stopping = 10
# number of trees(iterations), alias: num_tree, num_iteration, num_iterations, num_round, num_rounds
num_trees = 100
# shrinkage rate , alias: shrinkage_rate
learning_rate = 0.05
# number of leaves for one tree, alias: num_leaf
num_leaves = 31
...@@ -80,6 +80,13 @@ public: ...@@ -80,6 +80,13 @@ public:
const float* feature_values, const float* feature_values,
int num_used_model) const = 0; int num_used_model) const = 0;
/*!
* \brief Prediction for multiclass classification
* \param feature_values Feature value on this record
* \return Prediction result, num_class numbers per line
*/
virtual std::vector<float> PredictMulticlass(const float* value, int num_used_model) const = 0;
/*! /*!
* \brief save model to file * \brief save model to file
*/ */
...@@ -108,7 +115,13 @@ public: ...@@ -108,7 +115,13 @@ public:
* \return Number of weak sub-models * \return Number of weak sub-models
*/ */
virtual int NumberOfSubModels() const = 0; virtual int NumberOfSubModels() const = 0;
/*!
* \brief Get number of classes
* \return Number of classes
*/
virtual int NumberOfClass() const = 0;
/*! /*!
* \brief Get Type name of this boosting object * \brief Get Type name of this boosting object
*/ */
......
...@@ -128,6 +128,8 @@ public: ...@@ -128,6 +128,8 @@ public:
int max_position = 20; int max_position = 20;
// for binary // for binary
bool is_unbalance = false; bool is_unbalance = false;
// for multiclass
int num_class = 1;
void Set(const std::unordered_map<std::string, std::string>& params) override; void Set(const std::unordered_map<std::string, std::string>& params) override;
}; };
...@@ -135,6 +137,7 @@ public: ...@@ -135,6 +137,7 @@ public:
struct MetricConfig: public ConfigBase { struct MetricConfig: public ConfigBase {
public: public:
virtual ~MetricConfig() {} virtual ~MetricConfig() {}
int num_class = 1;
float sigmoid = 1.0f; float sigmoid = 1.0f;
std::vector<float> label_gain; std::vector<float> label_gain;
std::vector<int> eval_at; std::vector<int> eval_at;
...@@ -179,6 +182,7 @@ public: ...@@ -179,6 +182,7 @@ public:
int bagging_seed = 3; int bagging_seed = 3;
int bagging_freq = 0; int bagging_freq = 0;
int early_stopping_round = 0; int early_stopping_round = 0;
int num_class = 1;
void Set(const std::unordered_map<std::string, std::string>& params) override; void Set(const std::unordered_map<std::string, std::string>& params) override;
}; };
......
...@@ -336,6 +336,26 @@ static inline int64_t Pow2RoundUp(int64_t x) { ...@@ -336,6 +336,26 @@ static inline int64_t Pow2RoundUp(int64_t x) {
return 0; return 0;
} }
/*!
* \brief Do in-place softmax transformation on p_rec.
* Subtracts the maximum element before exponentiating so that
* std::exp cannot overflow for large scores (numerical stability).
* \param p_rec The input/output vector of values; an empty vector is a no-op.
*/
inline void Softmax(std::vector<float>* p_rec) {
std::vector<float>& rec = *p_rec;
// guard: original code read rec[0] unconditionally, which is UB on empty input
if (rec.empty()) {
return;
}
const float wmax = *std::max_element(rec.begin(), rec.end());
float wsum = 0.0f;
for (size_t i = 0; i < rec.size(); ++i) {
rec[i] = std::exp(rec[i] - wmax);
wsum += rec[i];
}
// normalize so the entries sum to 1
for (size_t i = 0; i < rec.size(); ++i) {
rec[i] /= wsum;
}
}
} // namespace Common } // namespace Common
} // namespace LightGBM } // namespace LightGBM
......
...@@ -33,6 +33,7 @@ public: ...@@ -33,6 +33,7 @@ public:
num_used_model_(num_used_model) { num_used_model_(num_used_model) {
boosting_ = boosting; boosting_ = boosting;
num_features_ = boosting_->MaxFeatureIdx() + 1; num_features_ = boosting_->MaxFeatureIdx() + 1;
num_class_ = boosting_->NumberOfClass();
#pragma omp parallel #pragma omp parallel
#pragma omp master #pragma omp master
{ {
...@@ -87,6 +88,18 @@ public: ...@@ -87,6 +88,18 @@ public:
// get result with sigmoid transform if needed // get result with sigmoid transform if needed
return boosting_->Predict(features_[tid], num_used_model_); return boosting_->Predict(features_[tid], num_used_model_);
} }
/*!
* \brief prediction for multiclass classification
* \param features Feature of this record
* \return Prediction result
*/
std::vector<float> PredictMulticlassOneLine(const std::vector<std::pair<int, float>>& features) {
const int tid = PutFeatureValuesToBuffer(features);
// get result with sigmoid transform if needed
return boosting_->PredictMulticlass(features_[tid], num_used_model_);
}
/*! /*!
* \brief predicting on data, then saving result to disk * \brief predicting on data, then saving result to disk
* \param data_filename Filename of data * \param data_filename Filename of data
...@@ -120,17 +133,30 @@ public: ...@@ -120,17 +133,30 @@ public:
}; };
std::function<std::string(const std::vector<std::pair<int, float>>&)> predict_fun; std::function<std::string(const std::vector<std::pair<int, float>>&)> predict_fun;
if (is_predict_leaf_index_) { if (num_class_ > 1) {
predict_fun = [this](const std::vector<std::pair<int, float>>& features){
std::vector<float> prediction = PredictMulticlassOneLine(features);
std::stringstream result_stream_buf;
for (size_t i = 0; i < prediction.size(); ++i){
if (i > 0) {
result_stream_buf << '\t';
}
result_stream_buf << prediction[i];
}
return result_stream_buf.str();
};
}
else if (is_predict_leaf_index_) {
predict_fun = [this](const std::vector<std::pair<int, float>>& features){ predict_fun = [this](const std::vector<std::pair<int, float>>& features){
std::vector<int> predicted_leaf_index = PredictLeafIndexOneLine(features); std::vector<int> predicted_leaf_index = PredictLeafIndexOneLine(features);
std::stringstream result_ss; std::stringstream result_stream_buf;
for (size_t i = 0; i < predicted_leaf_index.size(); ++i){ for (size_t i = 0; i < predicted_leaf_index.size(); ++i){
if (i > 0) { if (i > 0) {
result_ss << '\t'; result_stream_buf << '\t';
} }
result_ss << predicted_leaf_index[i]; result_stream_buf << predicted_leaf_index[i];
} }
return result_ss.str(); return result_stream_buf.str();
}; };
} }
else { else {
...@@ -189,6 +215,8 @@ private: ...@@ -189,6 +215,8 @@ private:
float** features_; float** features_;
/*! \brief Number of features */ /*! \brief Number of features */
int num_features_; int num_features_;
/*! \brief Number of classes */
int num_class_;
/*! \brief True if need to predict result with sigmoid transform */ /*! \brief True if need to predict result with sigmoid transform */
bool is_simgoid_; bool is_simgoid_;
/*! \brief Number of threads */ /*! \brief Number of threads */
......
...@@ -17,13 +17,15 @@ ...@@ -17,13 +17,15 @@
namespace LightGBM { namespace LightGBM {
GBDT::GBDT() GBDT::GBDT()
: tree_learner_(nullptr), train_score_updater_(nullptr), : train_score_updater_(nullptr),
gradients_(nullptr), hessians_(nullptr), gradients_(nullptr), hessians_(nullptr),
out_of_bag_data_indices_(nullptr), bag_data_indices_(nullptr) { out_of_bag_data_indices_(nullptr), bag_data_indices_(nullptr) {
} }
GBDT::~GBDT() { GBDT::~GBDT() {
if (tree_learner_ != nullptr) { delete tree_learner_; } for (auto& tree_learner: tree_learner_){
if (tree_learner != nullptr) { delete tree_learner; }
}
if (gradients_ != nullptr) { delete[] gradients_; } if (gradients_ != nullptr) { delete[] gradients_; }
if (hessians_ != nullptr) { delete[] hessians_; } if (hessians_ != nullptr) { delete[] hessians_; }
if (out_of_bag_data_indices_ != nullptr) { delete[] out_of_bag_data_indices_; } if (out_of_bag_data_indices_ != nullptr) { delete[] out_of_bag_data_indices_; }
...@@ -44,23 +46,27 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O ...@@ -44,23 +46,27 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
max_feature_idx_ = 0; max_feature_idx_ = 0;
early_stopping_round_ = gbdt_config_->early_stopping_round; early_stopping_round_ = gbdt_config_->early_stopping_round;
train_data_ = train_data; train_data_ = train_data;
num_class_ = config->num_class;
tree_learner_ = std::vector<TreeLearner*>(num_class_, nullptr);
// create tree learner // create tree learner
tree_learner_ = for (int i = 0; i < num_class_; ++i){
TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config); tree_learner_[i] =
// init tree learner TreeLearner::CreateTreeLearner(gbdt_config_->tree_learner_type, gbdt_config_->tree_config);
tree_learner_->Init(train_data_); // init tree learner
tree_learner_[i]->Init(train_data_);
}
object_function_ = object_function; object_function_ = object_function;
// push training metrics // push training metrics
for (const auto& metric : training_metrics) { for (const auto& metric : training_metrics) {
training_metrics_.push_back(metric); training_metrics_.push_back(metric);
} }
// create score tracker // create score tracker
train_score_updater_ = new ScoreUpdater(train_data_); train_score_updater_ = new ScoreUpdater(train_data_, num_class_);
num_data_ = train_data_->num_data(); num_data_ = train_data_->num_data();
// create buffer for gradients and hessians // create buffer for gradients and hessians
if (object_function_ != nullptr) { if (object_function_ != nullptr) {
gradients_ = new score_t[num_data_]; gradients_ = new score_t[num_data_ * num_class_];
hessians_ = new score_t[num_data_]; hessians_ = new score_t[num_data_ * num_class_];
} }
// get max feature index // get max feature index
...@@ -85,7 +91,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O ...@@ -85,7 +91,7 @@ void GBDT::Init(const BoostingConfig* config, const Dataset* train_data, const O
void GBDT::AddDataset(const Dataset* valid_data, void GBDT::AddDataset(const Dataset* valid_data,
const std::vector<const Metric*>& valid_metrics) { const std::vector<const Metric*>& valid_metrics) {
// for a validation dataset, we need its score and metric // for a validation dataset, we need its score and metric
valid_score_updater_.push_back(new ScoreUpdater(valid_data)); valid_score_updater_.push_back(new ScoreUpdater(valid_data, num_class_));
valid_metrics_.emplace_back(); valid_metrics_.emplace_back();
best_iter_.emplace_back(); best_iter_.emplace_back();
best_score_.emplace_back(); best_score_.emplace_back();
...@@ -97,7 +103,7 @@ void GBDT::AddDataset(const Dataset* valid_data, ...@@ -97,7 +103,7 @@ void GBDT::AddDataset(const Dataset* valid_data,
} }
void GBDT::Bagging(int iter) { void GBDT::Bagging(int iter, const int curr_class) {
// if need bagging // if need bagging
if (out_of_bag_data_indices_ != nullptr && iter % gbdt_config_->bagging_freq == 0) { if (out_of_bag_data_indices_ != nullptr && iter % gbdt_config_->bagging_freq == 0) {
// if doesn't have query data // if doesn't have query data
...@@ -146,52 +152,59 @@ void GBDT::Bagging(int iter) { ...@@ -146,52 +152,59 @@ void GBDT::Bagging(int iter) {
} }
Log::Info("re-bagging, using %d data to train", bag_data_cnt_); Log::Info("re-bagging, using %d data to train", bag_data_cnt_);
// set bagging data to tree learner // set bagging data to tree learner
tree_learner_->SetBaggingData(bag_data_indices_, bag_data_cnt_); tree_learner_[curr_class]->SetBaggingData(bag_data_indices_, bag_data_cnt_);
} }
} }
void GBDT::UpdateScoreOutOfBag(const Tree* tree) { void GBDT::UpdateScoreOutOfBag(const Tree* tree, const int curr_class) {
// we need to predict out-of-bag socres of data for boosting // we need to predict out-of-bag socres of data for boosting
if (out_of_bag_data_indices_ != nullptr) { if (out_of_bag_data_indices_ != nullptr) {
train_score_updater_-> train_score_updater_->
AddScore(tree, out_of_bag_data_indices_, out_of_bag_data_cnt_); AddScore(tree, out_of_bag_data_indices_, out_of_bag_data_cnt_, curr_class);
} }
} }
bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) { bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is_eval) {
// boosting first // boosting first
if (gradient == nullptr || hessian == nullptr) { if (gradient == nullptr || hessian == nullptr) {
Boosting(); Boosting();
gradient = gradients_; gradient = gradients_;
hessian = hessians_; hessian = hessians_;
} }
// bagging logic
Bagging(iter_); for (int curr_class = 0; curr_class < num_class_; ++curr_class){
// train a new tree // bagging logic
Tree * new_tree = tree_learner_->Train(gradient, hessian); Bagging(iter_, curr_class);
// if cannot learn a new tree, then stop
if (new_tree->num_leaves() <= 1) { // train a new tree
Log::Info("Can't training anymore, there isn't any leaf meets split requirements."); Tree * new_tree = tree_learner_[curr_class]->Train(gradient + curr_class * num_data_, hessian+ curr_class * num_data_);
return true; // if cannot learn a new tree, then stop
} if (new_tree->num_leaves() <= 1) {
// shrinkage by learning rate Log::Info("Can't training anymore, there isn't any leaf meets split requirements.");
new_tree->Shrinkage(gbdt_config_->learning_rate); return true;
// update score }
UpdateScore(new_tree);
UpdateScoreOutOfBag(new_tree); // shrinkage by learning rate
new_tree->Shrinkage(gbdt_config_->learning_rate);
// update score
UpdateScore(new_tree, curr_class);
UpdateScoreOutOfBag(new_tree, curr_class);
// add model
models_.push_back(new_tree);
}
bool is_met_early_stopping = false; bool is_met_early_stopping = false;
// print message for metric // print message for metric
if (is_eval) { if (is_eval) {
is_met_early_stopping = OutputMetric(iter_ + 1); is_met_early_stopping = OutputMetric(iter_ + 1);
} }
// add model
models_.push_back(new_tree);
++iter_; ++iter_;
if (is_met_early_stopping) { if (is_met_early_stopping) {
Log::Info("Early stopping at iteration %d, the best iteration round is %d", Log::Info("Early stopping at iteration %d, the best iteration round is %d",
iter_, iter_ - early_stopping_round_); iter_, iter_ - early_stopping_round_);
// pop last early_stopping_round_ models // pop last early_stopping_round_ models
for (int i = 0; i < early_stopping_round_; ++i) { for (int i = 0; i < early_stopping_round_ * num_class_; ++i) {
delete models_.back(); delete models_.back();
models_.pop_back(); models_.pop_back();
} }
...@@ -200,12 +213,12 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is ...@@ -200,12 +213,12 @@ bool GBDT::TrainOneIter(const score_t* gradient, const score_t* hessian, bool is
} }
void GBDT::UpdateScore(const Tree* tree) { void GBDT::UpdateScore(const Tree* tree, const int curr_class) {
// update training score // update training score
train_score_updater_->AddScore(tree_learner_); train_score_updater_->AddScore(tree_learner_[curr_class], curr_class);
// update validation score // update validation score
for (auto& score_tracker : valid_score_updater_) { for (auto& score_tracker : valid_score_updater_) {
score_tracker->AddScore(tree); score_tracker->AddScore(tree, curr_class);
} }
} }
...@@ -298,6 +311,8 @@ void GBDT::SaveModelToFile(bool is_finish, const char* filename) { ...@@ -298,6 +311,8 @@ void GBDT::SaveModelToFile(bool is_finish, const char* filename) {
model_output_file_.open(filename); model_output_file_.open(filename);
// output model type // output model type
model_output_file_ << "gbdt" << std::endl; model_output_file_ << "gbdt" << std::endl;
// output number of class
model_output_file_ << "num_class=" << num_class_ << std::endl;
// output label index // output label index
model_output_file_ << "label_index=" << label_idx_ << std::endl; model_output_file_ << "label_index=" << label_idx_ << std::endl;
// output max_feature_idx // output max_feature_idx
...@@ -311,7 +326,7 @@ void GBDT::SaveModelToFile(bool is_finish, const char* filename) { ...@@ -311,7 +326,7 @@ void GBDT::SaveModelToFile(bool is_finish, const char* filename) {
if (!model_output_file_.is_open()) { if (!model_output_file_.is_open()) {
return; return;
} }
int rest = static_cast<int>(models_.size()) - early_stopping_round_; int rest = static_cast<int>(models_.size()) - early_stopping_round_ * num_class_;
// output tree models // output tree models
for (int i = saved_model_size_; i < rest; ++i) { for (int i = saved_model_size_; i < rest; ++i) {
model_output_file_ << "Tree=" << i << std::endl; model_output_file_ << "Tree=" << i << std::endl;
...@@ -337,8 +352,26 @@ void GBDT::ModelsFromString(const std::string& model_str) { ...@@ -337,8 +352,26 @@ void GBDT::ModelsFromString(const std::string& model_str) {
models_.clear(); models_.clear();
std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n'); std::vector<std::string> lines = Common::Split(model_str.c_str(), '\n');
size_t i = 0; size_t i = 0;
// get number of class
while (i < lines.size()) {
size_t find_pos = lines[i].find("num_class=");
if (find_pos != std::string::npos) {
std::vector<std::string> strs = Common::Split(lines[i].c_str(), '=');
Common::Atoi(strs[1].c_str(), &num_class_);
++i;
break;
} else {
++i;
}
}
if (i == lines.size()) {
Log::Fatal("Model file doesn't contain number of class");
return;
}
// get index of label // get index of label
i = 0;
while (i < lines.size()) { while (i < lines.size()) {
size_t find_pos = lines[i].find("label_index="); size_t find_pos = lines[i].find("label_index=");
if (find_pos != std::string::npos) { if (find_pos != std::string::npos) {
...@@ -460,6 +493,20 @@ float GBDT::Predict(const float* value, int num_used_model) const { ...@@ -460,6 +493,20 @@ float GBDT::Predict(const float* value, int num_used_model) const {
return ret; return ret;
} }
/*!
* \brief Predict per-class probabilities for one record.
* Trees are stored interleaved: iteration i holds num_class_ consecutive
* trees, one per class, at models_[i * num_class_ + class].
* \param value Dense feature values of the record
* \param num_used_model Number of boosting iterations to use; negative means all
* \return Softmax-normalized score per class
*/
std::vector<float> GBDT::PredictMulticlass(const float* value, int num_used_model) const {
if (num_used_model < 0) {
// models_ holds num_class_ trees per iteration
num_used_model = static_cast<int>(models_.size()) / num_class_;
}
std::vector<float> raw_score(num_class_, 0.0f);
for (int iter = 0; iter < num_used_model; ++iter) {
const int base = iter * num_class_;
for (int cls = 0; cls < num_class_; ++cls) {
raw_score[cls] += models_[base + cls]->Predict(value);
}
}
// convert raw per-class sums into probabilities
Common::Softmax(&raw_score);
return raw_score;
}
std::vector<int> GBDT::PredictLeafIndex(const float* value, int num_used_model) const { std::vector<int> GBDT::PredictLeafIndex(const float* value, int num_used_model) const {
if (num_used_model < 0) { if (num_used_model < 0) {
num_used_model = static_cast<int>(models_.size()); num_used_model = static_cast<int>(models_.size());
......
...@@ -68,13 +68,20 @@ public: ...@@ -68,13 +68,20 @@ public:
*/ */
float Predict(const float* feature_values, int num_used_model) const override; float Predict(const float* feature_values, int num_used_model) const override;
/*!
* \brief Prediction for multiclass classification
* \param feature_values Feature value on this record
* \return Prediction result, num_class numbers per line
*/
std::vector<float> PredictMulticlass(const float* value, int num_used_model) const override;
/*! /*!
* \brief Predtion for one record with leaf index * \brief Predtion for one record with leaf index
* \param feature_values Feature value on this record * \param feature_values Feature value on this record
* \param num_used_model Number of used model * \param num_used_model Number of used model
* \return Predicted leaf index for this record * \return Predicted leaf index for this record
*/ */
std::vector<int> PredictLeafIndex(const float* value, int num_used_model) const override; std::vector<int> PredictLeafIndex(const float* value, int num_used_model) const override;
/*! /*!
* \brief Serialize models by string * \brief Serialize models by string
...@@ -103,6 +110,12 @@ public: ...@@ -103,6 +110,12 @@ public:
*/ */
inline int NumberOfSubModels() const override { return static_cast<int>(models_.size()); } inline int NumberOfSubModels() const override { return static_cast<int>(models_.size()); }
/*!
* \brief Get number of classes
* \return Number of classes
*/
inline int NumberOfClass() const override { return num_class_; }
/*! /*!
* \brief Get Type name of this boosting object * \brief Get Type name of this boosting object
*/ */
...@@ -112,14 +125,16 @@ private: ...@@ -112,14 +125,16 @@ private:
/*! /*!
* \brief Implement bagging logic * \brief Implement bagging logic
* \param iter Current interation * \param iter Current interation
* \param curr_class Current class for multiclass training
*/ */
void Bagging(int iter); void Bagging(int iter, const int curr_class);
/*! /*!
* \brief updating score for out-of-bag data. * \brief updating score for out-of-bag data.
* Data should be update since we may re-bagging data on training * Data should be update since we may re-bagging data on training
* \param tree Trained tree of this iteration * \param tree Trained tree of this iteration
* \param curr_class Current class for multiclass training
*/ */
void UpdateScoreOutOfBag(const Tree* tree); void UpdateScoreOutOfBag(const Tree* tree, const int curr_class);
/*! /*!
* \brief calculate the object function * \brief calculate the object function
*/ */
...@@ -127,8 +142,9 @@ private: ...@@ -127,8 +142,9 @@ private:
/*! /*!
* \brief updating score after tree was trained * \brief updating score after tree was trained
* \param tree Trained tree of this iteration * \param tree Trained tree of this iteration
* \param curr_class Current class for multiclass training
*/ */
void UpdateScore(const Tree* tree); void UpdateScore(const Tree* tree, const int curr_class);
/*! /*!
* \brief Print metric result of current iteration * \brief Print metric result of current iteration
* \param iter Current interation * \param iter Current interation
...@@ -146,7 +162,7 @@ private: ...@@ -146,7 +162,7 @@ private:
/*! \brief Config of gbdt */ /*! \brief Config of gbdt */
const GBDTConfig* gbdt_config_; const GBDTConfig* gbdt_config_;
/*! \brief Tree learner, will use this class to learn trees */ /*! \brief Tree learner, will use this class to learn trees */
TreeLearner* tree_learner_; std::vector<TreeLearner*> tree_learner_;
/*! \brief Objective function */ /*! \brief Objective function */
const ObjectiveFunction* object_function_; const ObjectiveFunction* object_function_;
/*! \brief Store and update training data's score */ /*! \brief Store and update training data's score */
...@@ -180,6 +196,8 @@ private: ...@@ -180,6 +196,8 @@ private:
data_size_t bag_data_cnt_; data_size_t bag_data_cnt_;
/*! \brief Number of traning data */ /*! \brief Number of traning data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Number of classes */
int num_class_;
/*! \brief Random generator, used for bagging */ /*! \brief Random generator, used for bagging */
Random random_; Random random_;
/*! /*!
......
...@@ -18,12 +18,12 @@ public: ...@@ -18,12 +18,12 @@ public:
* \brief Constructor, will pass a const pointer of dataset * \brief Constructor, will pass a const pointer of dataset
* \param data This class will bind with this data set * \param data This class will bind with this data set
*/ */
explicit ScoreUpdater(const Dataset* data) explicit ScoreUpdater(const Dataset* data, int num_class)
:data_(data) { :data_(data) {
num_data_ = data->num_data(); num_data_ = data->num_data();
score_ = new score_t[num_data_]; score_ = new score_t[num_data_ * num_class];
// default start score is zero // default start score is zero
std::memset(score_, 0, sizeof(score_t)*num_data_); std::memset(score_, 0, sizeof(score_t) * num_data_ * num_class);
const score_t* init_score = data->metadata().init_score(); const score_t* init_score = data->metadata().init_score();
// if exists initial score, will start from it // if exists initial score, will start from it
if (init_score != nullptr) { if (init_score != nullptr) {
...@@ -41,8 +41,8 @@ public: ...@@ -41,8 +41,8 @@ public:
* Note: this function generally will be used on validation data too. * Note: this function generally will be used on validation data too.
* \param tree Trained tree model * \param tree Trained tree model
*/ */
inline void AddScore(const Tree* tree) { inline void AddScore(const Tree* tree, int curr_class) {
tree->AddPredictionToScore(data_, num_data_, score_); tree->AddPredictionToScore(data_, num_data_, score_ + curr_class * num_data_);
} }
/*! /*!
* \brief Adding prediction score, only used for training data. * \brief Adding prediction score, only used for training data.
...@@ -50,8 +50,8 @@ public: ...@@ -50,8 +50,8 @@ public:
* Based on which We can get prediction quckily. * Based on which We can get prediction quckily.
* \param tree_learner * \param tree_learner
*/ */
inline void AddScore(const TreeLearner* tree_learner) { inline void AddScore(const TreeLearner* tree_learner, int curr_class) {
tree_learner->AddPredictionToScore(score_); tree_learner->AddPredictionToScore(score_ + curr_class * num_data_);
} }
/*! /*!
* \brief Using tree model to get prediction number, then adding to scores for parts of data * \brief Using tree model to get prediction number, then adding to scores for parts of data
...@@ -61,8 +61,8 @@ public: ...@@ -61,8 +61,8 @@ public:
* \param data_cnt Number of data that will be proccessed * \param data_cnt Number of data that will be proccessed
*/ */
inline void AddScore(const Tree* tree, const data_size_t* data_indices, inline void AddScore(const Tree* tree, const data_size_t* data_indices,
data_size_t data_cnt) { data_size_t data_cnt, int curr_class) {
tree->AddPredictionToScore(data_, data_indices, data_cnt, score_); tree->AddPredictionToScore(data_, data_indices, data_cnt, score_ + curr_class * num_data_);
} }
/*! \brief Pointer of score */ /*! \brief Pointer of score */
inline const score_t * score() { return score_; } inline const score_t * score() { return score_; }
...@@ -72,7 +72,7 @@ private: ...@@ -72,7 +72,7 @@ private:
data_size_t num_data_; data_size_t num_data_;
/*! \brief Pointer of data set */ /*! \brief Pointer of data set */
const Dataset* data_; const Dataset* data_;
/*! \brief scores for data set */ /*! \brief Scores for data set */
score_t* score_; score_t* score_;
}; };
......
...@@ -46,7 +46,6 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para ...@@ -46,7 +46,6 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
boosting_config = new GBDTConfig(); boosting_config = new GBDTConfig();
} }
// sub-config setup // sub-config setup
network_config.Set(params); network_config.Set(params);
io_config.Set(params); io_config.Set(params);
...@@ -132,7 +131,29 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin ...@@ -132,7 +131,29 @@ void OverallConfig::GetTaskType(const std::unordered_map<std::string, std::strin
} }
void OverallConfig::CheckParamConflict() { void OverallConfig::CheckParamConflict() {
GBDTConfig* gbdt_config = dynamic_cast<GBDTConfig*>(boosting_config); GBDTConfig* gbdt_config = dynamic_cast<GBDTConfig*>(boosting_config);
// check if objective_type, metric_type, and num_class match
bool objective_type_multiclass = (objective_type == std::string("multiclass"));
int num_class_check = gbdt_config->num_class;
if (objective_type_multiclass){
if (num_class_check <= 1){
Log::Fatal("You should specify number of class(>=2) for multiclass training.");
}
}
else {
if (task_type == TaskType::kTrain && num_class_check != 1){
Log::Fatal("Number of class must be 1 for non-multiclass training.");
}
}
for (std::string metric_type : metric_types){
bool metric_type_multiclass = ( metric_type == std::string("multi_logloss") || metric_type == std::string("multi_error"));
if ((objective_type_multiclass && !metric_type_multiclass)
|| (!objective_type_multiclass && metric_type_multiclass)){
Log::Fatal("Objective and metrics don't match.");
}
}
if (network_config.num_machines > 1) { if (network_config.num_machines > 1) {
is_parallel = true; is_parallel = true;
} else { } else {
...@@ -196,6 +217,8 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa ...@@ -196,6 +217,8 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa
GetFloat(params, "sigmoid", &sigmoid); GetFloat(params, "sigmoid", &sigmoid);
GetInt(params, "max_position", &max_position); GetInt(params, "max_position", &max_position);
CHECK(max_position > 0); CHECK(max_position > 0);
GetInt(params, "num_class", &num_class);
CHECK(num_class >= 1);
std::string tmp_str = ""; std::string tmp_str = "";
if (GetString(params, "label_gain", &tmp_str)) { if (GetString(params, "label_gain", &tmp_str)) {
label_gain = Common::StringToFloatArray(tmp_str, ','); label_gain = Common::StringToFloatArray(tmp_str, ',');
...@@ -212,6 +235,8 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa ...@@ -212,6 +235,8 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa
void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) { void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetFloat(params, "sigmoid", &sigmoid); GetFloat(params, "sigmoid", &sigmoid);
GetInt(params, "num_class", &num_class);
CHECK(num_class >= 1);
std::string tmp_str = ""; std::string tmp_str = "";
if (GetString(params, "label_gain", &tmp_str)) { if (GetString(params, "label_gain", &tmp_str)) {
label_gain = Common::StringToFloatArray(tmp_str, ','); label_gain = Common::StringToFloatArray(tmp_str, ',');
...@@ -268,6 +293,8 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par ...@@ -268,6 +293,8 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
GetInt(params, "metric_freq", &output_freq); GetInt(params, "metric_freq", &output_freq);
CHECK(output_freq >= 0); CHECK(output_freq >= 0);
GetBool(params, "is_training_metric", &is_provide_training_metric); GetBool(params, "is_training_metric", &is_provide_training_metric);
GetInt(params, "num_class", &num_class);
CHECK(num_class >= 1);
} }
void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) { void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "regression_metric.hpp" #include "regression_metric.hpp"
#include "binary_metric.hpp" #include "binary_metric.hpp"
#include "rank_metric.hpp" #include "rank_metric.hpp"
#include "multiclass_metric.hpp"
namespace LightGBM { namespace LightGBM {
...@@ -18,6 +19,10 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config ...@@ -18,6 +19,10 @@ Metric* Metric::CreateMetric(const std::string& type, const MetricConfig& config
return new AUCMetric(config); return new AUCMetric(config);
} else if (type == "ndcg") { } else if (type == "ndcg") {
return new NDCGMetric(config); return new NDCGMetric(config);
} else if (type == "multi_logloss"){
return new MultiLoglossMetric(config);
} else if (type == "multi_error"){
return new MultiErrorMetric(config);
} }
return nullptr; return nullptr;
} }
......
#ifndef LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_
#define LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_
#include <LightGBM/utils/log.h>
#include <LightGBM/metric.h>
#include <cmath>
namespace LightGBM {
/*!
* \brief Base metric for multiclass task.
*        Scores are laid out class-major: score[k * num_data_ + i] holds the
*        raw score of class k for data point i.
*        Use static class "PointWiseLossCalculator" to calculate loss point-wise
*/
template<typename PointWiseLossCalculator>
class MulticlassMetric: public Metric {
public:
  explicit MulticlassMetric(const MetricConfig& config) {
    num_class_ = config.num_class;
  }

  virtual ~MulticlassMetric() {
  }

  void Init(const char* test_name, const Metadata& metadata, data_size_t num_data) override {
    // display name, e.g. "valid_1's multi logloss"
    std::stringstream str_buf;
    str_buf << test_name << "'s " << PointWiseLossCalculator::Name();
    name_ = str_buf.str();
    num_data_ = num_data;
    // get label
    label_ = metadata.label();
    // get weights; nullptr means every data point has unit weight
    weights_ = metadata.weights();
    if (weights_ == nullptr) {
      sum_weights_ = static_cast<float>(num_data_);
    } else {
      sum_weights_ = 0.0f;
      for (data_size_t i = 0; i < num_data_; ++i) {
        sum_weights_ += weights_[i];
      }
    }
  }

  const char* GetName() const override {
    return name_.c_str();
  }

  /*! \brief Both multiclass metrics are losses: smaller is better */
  bool is_bigger_better() const override {
    return false;
  }

  /*!
  * \brief Weighted average of the point-wise loss over all data
  * \param score Raw (pre-softmax) scores, class-major layout
  * \return Single-element vector with the averaged loss
  */
  std::vector<score_t> Eval(const score_t* score) const override {
    score_t sum_loss = 0.0;
    if (weights_ == nullptr) {
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
      for (data_size_t i = 0; i < num_data_; ++i) {
        // gather the scores of all classes for data point i
        std::vector<score_t> rec(num_class_);
        for (int k = 0; k < num_class_; ++k) {
          rec[k] = score[k * num_data_ + i];
        }
        // add loss
        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec);
      }
    } else {
#pragma omp parallel for schedule(static) reduction(+:sum_loss)
      for (data_size_t i = 0; i < num_data_; ++i) {
        std::vector<score_t> rec(num_class_);
        for (int k = 0; k < num_class_; ++k) {
          rec[k] = score[k * num_data_ + i];
        }
        // add loss, scaled by the per-point weight
        sum_loss += PointWiseLossCalculator::LossOnPoint(label_[i], rec) * weights_[i];
      }
    }
    score_t loss = sum_loss / sum_weights_;
    return std::vector<score_t>(1, loss);
  }

private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief Pointer of label */
  const float* label_;
  /*! \brief Pointer of weighs; nullptr if unweighted */
  const float* weights_;
  /*! \brief Sum weights */
  float sum_weights_;
  /*! \brief Name of this test set */
  std::string name_;
};
/*! \brief 0-1 error (misclassification rate) for multiclass task */
class MultiErrorMetric: public MulticlassMetric<MultiErrorMetric> {
public:
  explicit MultiErrorMetric(const MetricConfig& config) :MulticlassMetric<MultiErrorMetric>(config) {}

  /*!
  * \brief 0-1 loss on one data point: 1 if any other class strictly
  *        outscores the true class (misclassified), 0 otherwise.
  *        Ties are resolved in favour of the true class.
  * \param label True class label (cast to an index)
  * \param score Raw scores of every class for this data point (read-only)
  */
  inline static score_t LossOnPoint(float label, const std::vector<score_t>& score) {
    size_t k = static_cast<size_t>(label);
    for (size_t i = 0; i < score.size(); ++i) {
      if (i != k && score[i] > score[k]) {
        // some other class beats the true class -> error
        return 1.0f;
      }
    }
    // true class has the highest score -> no error
    return 0.0f;
  }

  inline static const char* Name() {
    return "multi error";
  }
};
/*! \brief Logloss for multiclass task */
class MultiLoglossMetric: public MulticlassMetric<MultiLoglossMetric> {
public:
  explicit MultiLoglossMetric(const MetricConfig& config) :MulticlassMetric<MultiLoglossMetric>(config) {}

  /*!
  * \brief Negative log-likelihood of the true class on one data point.
  * \param label True class label (cast to an index)
  * \param score Raw scores, taken by value because Softmax normalizes in place
  */
  inline static score_t LossOnPoint(float label, std::vector<score_t> score) {
    size_t true_class = static_cast<size_t>(label);
    // convert raw scores to probabilities
    Common::Softmax(&score);
    // clamp at kEpsilon so the log never receives zero
    score_t prob = score[true_class];
    if (prob <= kEpsilon) {
      prob = kEpsilon;
    }
    return -std::log(prob);
  }

  inline static const char* Name() {
    return "multi logloss";
  }
};
} // namespace LightGBM
#endif   // LIGHTGBM_METRIC_MULTICLASS_METRIC_HPP_
...@@ -8,7 +8,7 @@ ...@@ -8,7 +8,7 @@
namespace LightGBM { namespace LightGBM {
/*! /*!
* \brief Objective funtion for binary classification * \brief Objective function for binary classification
*/ */
class BinaryLogloss: public ObjectiveFunction { class BinaryLogloss: public ObjectiveFunction {
public: public:
......
#ifndef LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#define LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
#include <LightGBM/objective_function.h>

#include <cmath>
#include <cstring>
#include <vector>
namespace LightGBM {
/*!
* \brief Objective function for multiclass classification (softmax + logloss).
*        Gradients/hessians use the same class-major layout as the scores:
*        index [k * num_data_ + i] belongs to class k, data point i.
*/
class MulticlassLogloss: public ObjectiveFunction {
public:
  explicit MulticlassLogloss(const ObjectiveConfig& config) {
    num_class_ = config.num_class;
  }

  ~MulticlassLogloss() {
  }

  void Init(const Metadata& metadata, data_size_t num_data) override {
    num_data_ = num_data;
    label_ = metadata.label();
    weights_ = metadata.weights();
    // pre-convert labels to integers once, validating the range
    label_int_.resize(num_data_);
    for (data_size_t i = 0; i < num_data_; ++i) {
      label_int_[i] = static_cast<int>(label_[i]);
      if (label_int_[i] < 0 || label_int_[i] >= num_class_) {
        Log::Fatal("Label must be in [0, %d), but find %d in label", num_class_, label_int_[i]);
      }
    }
  }

  /*!
  * \brief First- and second-order derivatives of softmax cross-entropy
  * \param score Raw scores, class-major layout
  * \param gradients Output gradients, class-major layout
  * \param hessians Output hessians, class-major layout
  */
  void GetGradients(const score_t* score, score_t* gradients, score_t* hessians) const override {
    if (weights_ == nullptr) {
#pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data_; ++i) {
        // gather scores of all classes for data point i
        std::vector<score_t> rec(num_class_);
        for (int k = 0; k < num_class_; ++k) {
          rec[k] = score[k * num_data_ + i];
        }
        Common::Softmax(&rec);
        for (int k = 0; k < num_class_; ++k) {
          const score_t p = rec[k];
          // gradient of logloss w.r.t. raw score: p - 1 for the true class, p otherwise
          if (label_int_[i] == k) {
            gradients[k * num_data_ + i] = p - 1.0f;
          } else {
            gradients[k * num_data_ + i] = p;
          }
          // NOTE(review): scaled by 2 relative to the analytic diagonal p*(1-p);
          // presumably intentional damping -- confirm before changing
          hessians[k * num_data_ + i] = 2.0f * p * (1.0f - p);
        }
      }
    } else {
#pragma omp parallel for schedule(static)
      for (data_size_t i = 0; i < num_data_; ++i) {
        std::vector<score_t> rec(num_class_);
        for (int k = 0; k < num_class_; ++k) {
          rec[k] = score[k * num_data_ + i];
        }
        Common::Softmax(&rec);
        for (int k = 0; k < num_class_; ++k) {
          const score_t p = rec[k];
          // same derivatives as the unweighted path, scaled by the point weight
          if (label_int_[i] == k) {
            gradients[k * num_data_ + i] = (p - 1.0f) * weights_[i];
          } else {
            gradients[k * num_data_ + i] = p * weights_[i];
          }
          hessians[k * num_data_ + i] = 2.0f * p * (1.0f - p) * weights_[i];
        }
      }
    }
  }

  /*! \brief No sigmoid transform for multiclass; -1 marks it unused */
  float GetSigmoid() const override {
    return -1.0f;
  }

private:
  /*! \brief Number of data */
  data_size_t num_data_;
  /*! \brief Number of classes */
  int num_class_;
  /*! \brief Pointer of label */
  const float* label_;
  /*! \brief Corresponding integers of label_ (owned; RAII via std::vector) */
  std::vector<int> label_int_;
  /*! \brief Weights for data; nullptr if unweighted */
  const float* weights_;
};
} // namespace LightGBM
#endif   // LIGHTGBM_OBJECTIVE_MULTICLASS_OBJECTIVE_HPP_
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include "regression_objective.hpp" #include "regression_objective.hpp"
#include "binary_objective.hpp" #include "binary_objective.hpp"
#include "rank_objective.hpp" #include "rank_objective.hpp"
#include "multiclass_objective.hpp"
namespace LightGBM { namespace LightGBM {
...@@ -12,6 +13,8 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string& ...@@ -12,6 +13,8 @@ ObjectiveFunction* ObjectiveFunction::CreateObjectiveFunction(const std::string&
return new BinaryLogloss(config); return new BinaryLogloss(config);
} else if (type == "lambdarank") { } else if (type == "lambdarank") {
return new LambdarankNDCG(config); return new LambdarankNDCG(config);
} else if (type == "multiclass") {
return new MulticlassLogloss(config);
} }
return nullptr; return nullptr;
} }
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
namespace LightGBM { namespace LightGBM {
/*! /*!
* \brief Objective funtion for Lambdrank with NDCG * \brief Objective function for Lambdrank with NDCG
*/ */
class LambdarankNDCG: public ObjectiveFunction { class LambdarankNDCG: public ObjectiveFunction {
public: public:
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
namespace LightGBM { namespace LightGBM {
/*! /*!
* \brief Objective funtion for regression * \brief Objective function for regression
*/ */
class RegressionL2loss: public ObjectiveFunction { class RegressionL2loss: public ObjectiveFunction {
public: public:
......
...@@ -185,11 +185,13 @@ ...@@ -185,11 +185,13 @@
<ClInclude Include="..\src\metric\binary_metric.hpp" /> <ClInclude Include="..\src\metric\binary_metric.hpp" />
<ClInclude Include="..\src\metric\rank_metric.hpp" /> <ClInclude Include="..\src\metric\rank_metric.hpp" />
<ClInclude Include="..\src\metric\regression_metric.hpp" /> <ClInclude Include="..\src\metric\regression_metric.hpp" />
<ClInclude Include="..\src\metric\multiclass_metric.hpp" />
<ClInclude Include="..\src\network\linkers.h" /> <ClInclude Include="..\src\network\linkers.h" />
<ClInclude Include="..\src\network\socket_wrapper.hpp" /> <ClInclude Include="..\src\network\socket_wrapper.hpp" />
<ClInclude Include="..\src\objective\binary_objective.hpp" /> <ClInclude Include="..\src\objective\binary_objective.hpp" />
<ClInclude Include="..\src\objective\rank_objective.hpp" /> <ClInclude Include="..\src\objective\rank_objective.hpp" />
<ClInclude Include="..\src\objective\regression_objective.hpp" /> <ClInclude Include="..\src\objective\regression_objective.hpp" />
<ClInclude Include="..\src\objective\multiclass_objective.hpp" />
<ClInclude Include="..\src\treelearner\data_partition.hpp" /> <ClInclude Include="..\src\treelearner\data_partition.hpp" />
<ClInclude Include="..\src\treelearner\feature_histogram.hpp" /> <ClInclude Include="..\src\treelearner\feature_histogram.hpp" />
<ClInclude Include="..\src\treelearner\leaf_splits.hpp" /> <ClInclude Include="..\src\treelearner\leaf_splits.hpp" />
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment