Commit 01ed04df authored by wxchan's avatar wxchan Committed by Guolin Ke
Browse files

support init_score for multiclass classification (#62)

support init_score for multiclass classification (#62)
parent 665c9dba
Multiclass Classification Example
=================================
Here is an example of using LightGBM to run a multiclass classification task.
***You should copy the executable file to this folder first.***
#### Training
For Windows, run the following command in this folder:
```
lightgbm.exe config=train.conf
```
For Linux, run the following command in this folder:
```
./lightgbm config=train.conf
```
#### Prediction
You should finish the training step first.
For Windows, run the following command in this folder:
```
lightgbm.exe config=predict.conf
```
For Linux, run the following command in this folder:
```
./lightgbm config=predict.conf
```
...@@ -86,6 +86,7 @@ enum TaskType { ...@@ -86,6 +86,7 @@ enum TaskType {
struct IOConfig: public ConfigBase { struct IOConfig: public ConfigBase {
public: public:
int max_bin = 256; int max_bin = 256;
int num_class = 1;
int data_random_seed = 1; int data_random_seed = 1;
std::string data_filename = ""; std::string data_filename = "";
std::vector<std::string> valid_data_filenames; std::vector<std::string> valid_data_filenames;
......
...@@ -41,14 +41,15 @@ public: ...@@ -41,14 +41,15 @@ public:
* \brief Initialization will load qurey level informations, since it is need for sampling data * \brief Initialization will load qurey level informations, since it is need for sampling data
* \param data_filename Filename of data * \param data_filename Filename of data
* \param init_score_filename Filename of initial score * \param init_score_filename Filename of initial score
* \param is_int_label True if label is int type * \param num_class Number of classes
*/ */
void Init(const char* data_filename, const char* init_score_filename); void Init(const char* data_filename, const char* init_score_filename, const int num_class);
/*! /*!
* \brief Initialize, only load initial score * \brief Initialize, only load initial score
* \param init_score_filename Filename of initial score * \param init_score_filename Filename of initial score
* \param num_class Number of classes
*/ */
void Init(const char* init_score_filename); void Init(const char* init_score_filename, const int num_class);
/*! /*!
* \brief Initial with binary memory * \brief Initial with binary memory
* \param memory Pointer to memory * \param memory Pointer to memory
...@@ -60,10 +61,11 @@ public: ...@@ -60,10 +61,11 @@ public:
/*! /*!
* \brief Initial work, will allocate space for label, weight(if exists) and query(if exists) * \brief Initial work, will allocate space for label, weight(if exists) and query(if exists)
* \param num_data Number of training data * \param num_data Number of training data
* \param num_class Number of classes
* \param weight_idx Index of weight column, < 0 means doesn't exists * \param weight_idx Index of weight column, < 0 means doesn't exists
* \param query_idx Index of query id column, < 0 means doesn't exists * \param query_idx Index of query id column, < 0 means doesn't exists
*/ */
void Init(data_size_t num_data, int weight_idx, int query_idx); void Init(data_size_t num_data, int num_class, int weight_idx, int query_idx);
/*! /*!
* \brief Partition label by used indices * \brief Partition label by used indices
...@@ -184,6 +186,8 @@ private: ...@@ -184,6 +186,8 @@ private:
const char* init_score_filename_; const char* init_score_filename_;
/*! \brief Number of data */ /*! \brief Number of data */
data_size_t num_data_; data_size_t num_data_;
/*! \brief Number of classes */
int num_class_;
/*! \brief Number of weights, used to check correct weight file */ /*! \brief Number of weights, used to check correct weight file */
data_size_t num_weights_; data_size_t num_weights_;
/*! \brief Label data */ /*! \brief Label data */
...@@ -234,7 +238,7 @@ public: ...@@ -234,7 +238,7 @@ public:
}; };
using PredictFunction = using PredictFunction =
std::function<double(const std::vector<std::pair<int, double>>&)>; std::function<std::vector<double>(const std::vector<std::pair<int, double>>&)>;
/*! \brief The main class of data set, /*! \brief The main class of data set,
* which are used to traning or validation * which are used to traning or validation
...@@ -398,6 +402,8 @@ private: ...@@ -398,6 +402,8 @@ private:
int num_total_features_; int num_total_features_;
/*! \brief Number of total data*/ /*! \brief Number of total data*/
data_size_t num_data_; data_size_t num_data_;
/*! \brief Number of classes*/
int num_class_;
/*! \brief Store some label level data*/ /*! \brief Store some label level data*/
Metadata metadata_; Metadata metadata_;
/*! \brief Random generator*/ /*! \brief Random generator*/
......
...@@ -124,10 +124,17 @@ void Application::LoadData() { ...@@ -124,10 +124,17 @@ void Application::LoadData() {
// need to continue train // need to continue train
if (boosting_->NumberOfSubModels() > 0) { if (boosting_->NumberOfSubModels() > 0) {
predictor = new Predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index, -1); predictor = new Predictor(boosting_, config_.io_config.is_sigmoid, config_.predict_leaf_index, -1);
if (config_.io_config.num_class == 1){
predict_fun = predict_fun =
[&predictor](const std::vector<std::pair<int, double>>& features) { [&predictor](const std::vector<std::pair<int, double>>& features) {
return predictor->PredictRawOneLine(features); return predictor->PredictRawOneLine(features);
}; };
} else {
predict_fun =
[&predictor](const std::vector<std::pair<int, double>>& features) {
return predictor->PredictMulticlassOneLine(features);
};
}
} }
// sync up random seed for data partition // sync up random seed for data partition
if (config_.is_parallel_find_bin) { if (config_.is_parallel_find_bin) {
......
...@@ -61,10 +61,10 @@ public: ...@@ -61,10 +61,10 @@ public:
* \param features Feature for this record * \param features Feature for this record
* \return Prediction result * \return Prediction result
*/ */
double PredictRawOneLine(const std::vector<std::pair<int, double>>& features) { std::vector<double> PredictRawOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result without sigmoid transformation // get result without sigmoid transformation
return boosting_->PredictRaw(features_[tid], num_used_model_); return std::vector<double>(1, boosting_->PredictRaw(features_[tid], num_used_model_));
} }
/*! /*!
...@@ -83,10 +83,10 @@ public: ...@@ -83,10 +83,10 @@ public:
* \param features Feature of this record * \param features Feature of this record
* \return Prediction result * \return Prediction result
*/ */
double PredictOneLine(const std::vector<std::pair<int, double>>& features) { std::vector<double> PredictOneLine(const std::vector<std::pair<int, double>>& features) {
const int tid = PutFeatureValuesToBuffer(features); const int tid = PutFeatureValuesToBuffer(features);
// get result with sigmoid transform if needed // get result with sigmoid transform if needed
return boosting_->Predict(features_[tid], num_used_model_); return std::vector<double>(1, boosting_->Predict(features_[tid], num_used_model_));
} }
/*! /*!
...@@ -136,6 +136,7 @@ public: ...@@ -136,6 +136,7 @@ public:
if (num_class_ > 1) { if (num_class_ > 1) {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){ predict_fun = [this](const std::vector<std::pair<int, double>>& features){
std::vector<double> prediction = PredictMulticlassOneLine(features); std::vector<double> prediction = PredictMulticlassOneLine(features);
Common::Softmax(&prediction);
std::stringstream result_stream_buf; std::stringstream result_stream_buf;
for (size_t i = 0; i < prediction.size(); ++i){ for (size_t i = 0; i < prediction.size(); ++i){
if (i > 0) { if (i > 0) {
...@@ -162,12 +163,12 @@ public: ...@@ -162,12 +163,12 @@ public:
else { else {
if (is_simgoid_) { if (is_simgoid_) {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){ predict_fun = [this](const std::vector<std::pair<int, double>>& features){
return std::to_string(PredictOneLine(features)); return std::to_string(PredictOneLine(features)[0]);
}; };
} }
else { else {
predict_fun = [this](const std::vector<std::pair<int, double>>& features){ predict_fun = [this](const std::vector<std::pair<int, double>>& features){
return std::to_string(PredictRawOneLine(features)); return std::to_string(PredictRawOneLine(features)[0]);
}; };
} }
} }
......
...@@ -503,7 +503,6 @@ std::vector<double> GBDT::PredictMulticlass(const double* value, int num_used_mo ...@@ -503,7 +503,6 @@ std::vector<double> GBDT::PredictMulticlass(const double* value, int num_used_mo
ret[j] += models_[i * num_class_ + j] -> Predict(value); ret[j] += models_[i * num_class_ + j] -> Predict(value);
} }
} }
Common::Softmax(&ret);
return ret; return ret;
} }
......
...@@ -27,7 +27,7 @@ public: ...@@ -27,7 +27,7 @@ public:
const float* init_score = data->metadata().init_score(); const float* init_score = data->metadata().init_score();
// if exists initial score, will start from it // if exists initial score, will start from it
if (init_score != nullptr) { if (init_score != nullptr) {
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_ * num_class; ++i) {
score_[i] = init_score[i]; score_[i] = init_score[i];
} }
} }
......
...@@ -184,6 +184,7 @@ void OverallConfig::CheckParamConflict() { ...@@ -184,6 +184,7 @@ void OverallConfig::CheckParamConflict() {
void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "max_bin", &max_bin); GetInt(params, "max_bin", &max_bin);
CHECK(max_bin > 0); CHECK(max_bin > 0);
GetInt(params, "num_class", &num_class);
GetInt(params, "data_random_seed", &data_random_seed); GetInt(params, "data_random_seed", &data_random_seed);
if (!GetString(params, "data", &data_filename)) { if (!GetString(params, "data", &data_filename)) {
...@@ -236,7 +237,6 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa ...@@ -236,7 +237,6 @@ void ObjectiveConfig::Set(const std::unordered_map<std::string, std::string>& pa
void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) { void MetricConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetDouble(params, "sigmoid", &sigmoid); GetDouble(params, "sigmoid", &sigmoid);
GetInt(params, "num_class", &num_class); GetInt(params, "num_class", &num_class);
CHECK(num_class >= 1);
std::string tmp_str = ""; std::string tmp_str = "";
if (GetString(params, "label_gain", &tmp_str)) { if (GetString(params, "label_gain", &tmp_str)) {
label_gain = Common::StringToDoubleArray(tmp_str, ','); label_gain = Common::StringToDoubleArray(tmp_str, ',');
...@@ -294,7 +294,6 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par ...@@ -294,7 +294,6 @@ void BoostingConfig::Set(const std::unordered_map<std::string, std::string>& par
CHECK(output_freq >= 0); CHECK(output_freq >= 0);
GetBool(params, "is_training_metric", &is_provide_training_metric); GetBool(params, "is_training_metric", &is_provide_training_metric);
GetInt(params, "num_class", &num_class); GetInt(params, "num_class", &num_class);
CHECK(num_class >= 1);
} }
void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) { void GBDTConfig::GetTreeLearnerType(const std::unordered_map<std::string, std::string>& params) {
......
...@@ -20,6 +20,8 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -20,6 +20,8 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
:data_filename_(data_filename), random_(io_config.data_random_seed), :data_filename_(data_filename), random_(io_config.data_random_seed),
max_bin_(io_config.max_bin), is_enable_sparse_(io_config.is_enable_sparse), predict_fun_(predict_fun) { max_bin_(io_config.max_bin), is_enable_sparse_(io_config.is_enable_sparse), predict_fun_(predict_fun) {
num_class_ = io_config.num_class;
CheckCanLoadFromBin(); CheckCanLoadFromBin();
if (is_loading_from_binfile_ && predict_fun != nullptr) { if (is_loading_from_binfile_ && predict_fun != nullptr) {
Log::Info("Cannot performing initialization of prediction by using binary file, using text file instead"); Log::Info("Cannot performing initialization of prediction by using binary file, using text file instead");
...@@ -28,7 +30,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -28,7 +30,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
if (!is_loading_from_binfile_) { if (!is_loading_from_binfile_) {
// load weight, query information and initilize score // load weight, query information and initilize score
metadata_.Init(data_filename, init_score_filename); metadata_.Init(data_filename, init_score_filename, num_class_);
// create text reader // create text reader
text_reader_ = new TextReader<data_size_t>(data_filename, io_config.has_header); text_reader_ = new TextReader<data_size_t>(data_filename, io_config.has_header);
...@@ -152,7 +154,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -152,7 +154,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
} }
} else { } else {
// only need to load initilize score, other meta data will be loaded from bin flie // only need to load initilize score, other meta data will be loaded from bin flie
metadata_.Init(init_score_filename); metadata_.Init(init_score_filename, num_class_);
Log::Info("Loading data set from binary file"); Log::Info("Loading data set from binary file");
parser_ = nullptr; parser_ = nullptr;
text_reader_ = nullptr; text_reader_ = nullptr;
...@@ -436,7 +438,7 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b ...@@ -436,7 +438,7 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b
// construct feature bin mappers // construct feature bin mappers
ConstructBinMappers(rank, num_machines, sample_data); ConstructBinMappers(rank, num_machines, sample_data);
// initialize label // initialize label
metadata_.Init(num_data_, weight_idx_, group_idx_); metadata_.Init(num_data_, num_class_, weight_idx_, group_idx_);
// extract features // extract features
ExtractFeaturesFromMemory(); ExtractFeaturesFromMemory();
} else { } else {
...@@ -446,7 +448,7 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b ...@@ -446,7 +448,7 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b
// construct feature bin mappers // construct feature bin mappers
ConstructBinMappers(rank, num_machines, sample_data); ConstructBinMappers(rank, num_machines, sample_data);
// initialize label // initialize label
metadata_.Init(num_data_, weight_idx_, group_idx_); metadata_.Init(num_data_, num_class_, weight_idx_, group_idx_);
// extract features // extract features
ExtractFeaturesFromFile(); ExtractFeaturesFromFile();
...@@ -471,7 +473,7 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo ...@@ -471,7 +473,7 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo
// read data in memory // read data in memory
LoadDataToMemory(0, 1, false); LoadDataToMemory(0, 1, false);
// initialize label // initialize label
metadata_.Init(num_data_, weight_idx_, group_idx_); metadata_.Init(num_data_, num_class_, weight_idx_, group_idx_);
features_.clear(); features_.clear();
// copy feature bin mapper data // copy feature bin mapper data
for (Feature* feature : train_set->features_) { for (Feature* feature : train_set->features_) {
...@@ -487,7 +489,7 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo ...@@ -487,7 +489,7 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo
// Get number of lines of data file // Get number of lines of data file
num_data_ = static_cast<data_size_t>(text_reader_->CountLine()); num_data_ = static_cast<data_size_t>(text_reader_->CountLine());
// initialize label // initialize label
metadata_.Init(num_data_, weight_idx_, group_idx_); metadata_.Init(num_data_, num_class_, weight_idx_, group_idx_);
features_.clear(); features_.clear();
// copy feature bin mapper data // copy feature bin mapper data
for (Feature* feature : train_set->features_) { for (Feature* feature : train_set->features_) {
...@@ -545,7 +547,7 @@ void Dataset::ExtractFeaturesFromMemory() { ...@@ -545,7 +547,7 @@ void Dataset::ExtractFeaturesFromMemory() {
} }
} else { } else {
// if need to prediction with initial model // if need to prediction with initial model
float* init_score = new float[num_data_]; float* init_score = new float[num_data_ * num_class_];
#pragma omp parallel for schedule(guided) private(oneline_features) firstprivate(tmp_label) #pragma omp parallel for schedule(guided) private(oneline_features) firstprivate(tmp_label)
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
const int tid = omp_get_thread_num(); const int tid = omp_get_thread_num();
...@@ -553,7 +555,10 @@ void Dataset::ExtractFeaturesFromMemory() { ...@@ -553,7 +555,10 @@ void Dataset::ExtractFeaturesFromMemory() {
// parser // parser
parser_->ParseOneLine(text_reader_->Lines()[i].c_str(), &oneline_features, &tmp_label); parser_->ParseOneLine(text_reader_->Lines()[i].c_str(), &oneline_features, &tmp_label);
// set initial score // set initial score
init_score[i] = static_cast<float>(predict_fun_(oneline_features)); std::vector<double> oneline_init_score = predict_fun_(oneline_features);
for (int k = 0; k < num_class_; ++k){
init_score[k * num_data_ + i] = static_cast<float>(oneline_init_score[k]);
}
// set label // set label
metadata_.SetLabelAt(i, static_cast<float>(tmp_label)); metadata_.SetLabelAt(i, static_cast<float>(tmp_label));
// free processed line: // free processed line:
...@@ -577,7 +582,7 @@ void Dataset::ExtractFeaturesFromMemory() { ...@@ -577,7 +582,7 @@ void Dataset::ExtractFeaturesFromMemory() {
} }
} }
// metadata_ will manage space of init_score // metadata_ will manage space of init_score
metadata_.SetInitScore(init_score, num_data_); metadata_.SetInitScore(init_score, num_data_ * num_class_);
delete[] init_score; delete[] init_score;
} }
...@@ -593,7 +598,7 @@ void Dataset::ExtractFeaturesFromMemory() { ...@@ -593,7 +598,7 @@ void Dataset::ExtractFeaturesFromMemory() {
void Dataset::ExtractFeaturesFromFile() { void Dataset::ExtractFeaturesFromFile() {
float* init_score = nullptr; float* init_score = nullptr;
if (predict_fun_ != nullptr) { if (predict_fun_ != nullptr) {
init_score = new float[num_data_]; init_score = new float[num_data_ * num_class_];
} }
std::function<void(data_size_t, const std::vector<std::string>&)> process_fun = std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
[this, &init_score] [this, &init_score]
...@@ -608,7 +613,10 @@ void Dataset::ExtractFeaturesFromFile() { ...@@ -608,7 +613,10 @@ void Dataset::ExtractFeaturesFromFile() {
parser_->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label); parser_->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
// set initial score // set initial score
if (init_score != nullptr) { if (init_score != nullptr) {
init_score[start_idx + i] = static_cast<float>(predict_fun_(oneline_features)); std::vector<double> oneline_init_score = predict_fun_(oneline_features);
for (int k = 0; k < num_class_; ++k){
init_score[k * num_data_ + start_idx + i] = static_cast<float>(oneline_init_score[k]);
}
} }
// set label // set label
metadata_.SetLabelAt(start_idx + i, static_cast<float>(tmp_label)); metadata_.SetLabelAt(start_idx + i, static_cast<float>(tmp_label));
...@@ -640,7 +648,7 @@ void Dataset::ExtractFeaturesFromFile() { ...@@ -640,7 +648,7 @@ void Dataset::ExtractFeaturesFromFile() {
// metadata_ will manage space of init_score // metadata_ will manage space of init_score
if (init_score != nullptr) { if (init_score != nullptr) {
metadata_.SetInitScore(init_score, num_data_); metadata_.SetInitScore(init_score, num_data_ * num_class_);
delete[] init_score; delete[] init_score;
} }
......
...@@ -14,9 +14,10 @@ Metadata::Metadata() ...@@ -14,9 +14,10 @@ Metadata::Metadata()
} }
void Metadata::Init(const char * data_filename, const char* init_score_filename) { void Metadata::Init(const char * data_filename, const char* init_score_filename, const int num_class) {
data_filename_ = data_filename; data_filename_ = data_filename;
init_score_filename_ = init_score_filename; init_score_filename_ = init_score_filename;
num_class_ = num_class;
// for lambdarank, it needs query data for partition data in parallel learning // for lambdarank, it needs query data for partition data in parallel learning
LoadQueryBoundaries(); LoadQueryBoundaries();
LoadWeights(); LoadWeights();
...@@ -24,8 +25,9 @@ void Metadata::Init(const char * data_filename, const char* init_score_filename) ...@@ -24,8 +25,9 @@ void Metadata::Init(const char * data_filename, const char* init_score_filename)
LoadInitialScore(); LoadInitialScore();
} }
void Metadata::Init(const char* init_score_filename) { void Metadata::Init(const char* init_score_filename, const int num_class) {
init_score_filename_ = init_score_filename; init_score_filename_ = init_score_filename;
num_class_ = num_class;
LoadInitialScore(); LoadInitialScore();
} }
...@@ -40,8 +42,9 @@ Metadata::~Metadata() { ...@@ -40,8 +42,9 @@ Metadata::~Metadata() {
} }
void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) { void Metadata::Init(data_size_t num_data, int num_class, int weight_idx, int query_idx) {
num_data_ = num_data; num_data_ = num_data;
num_class_ = num_class;
label_ = new float[num_data_]; label_ = new float[num_data_];
if (weight_idx >= 0) { if (weight_idx >= 0) {
if (weights_ != nullptr) { if (weights_ != nullptr) {
...@@ -200,9 +203,11 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -200,9 +203,11 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
if (init_score_ != nullptr) { if (init_score_ != nullptr) {
float* old_scores = init_score_; float* old_scores = init_score_;
num_init_score_ = num_data_; num_init_score_ = num_data_;
init_score_ = new float[num_init_score_]; init_score_ = new float[num_init_score_ * num_class_];
for (int k = 0; k < num_class_; ++k){
for (size_t i = 0; i < used_data_indices.size(); ++i) { for (size_t i = 0; i < used_data_indices.size(); ++i) {
init_score_[i] = old_scores[used_data_indices[i]]; init_score_[k * num_data_ + i] = old_scores[k * num_all_data + used_data_indices[i]];
}
} }
delete[] old_scores; delete[] old_scores;
} }
...@@ -214,13 +219,13 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -214,13 +219,13 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
void Metadata::SetInitScore(const float* init_score, data_size_t len) { void Metadata::SetInitScore(const float* init_score, data_size_t len) {
if (num_data_ != len) { if (len != num_data_ * num_class_) {
Log::Fatal("len of initial score is not same with #data"); Log::Fatal("Length of initial score is not same with number of data");
} }
if (init_score_ != nullptr) { delete[] init_score_; } if (init_score_ != nullptr) { delete[] init_score_; }
num_init_score_ = num_data_; num_init_score_ = num_data_;
init_score_ = new float[num_init_score_]; init_score_ = new float[len];
for (data_size_t i = 0; i < num_init_score_; ++i) { for (data_size_t i = 0; i < len; ++i) {
init_score_[i] = init_score[i]; init_score_[i] = init_score[i];
} }
} }
...@@ -253,12 +258,28 @@ void Metadata::LoadInitialScore() { ...@@ -253,12 +258,28 @@ void Metadata::LoadInitialScore() {
Log::Info("Start loading initial scores"); Log::Info("Start loading initial scores");
num_init_score_ = static_cast<data_size_t>(reader.Lines().size()); num_init_score_ = static_cast<data_size_t>(reader.Lines().size());
init_score_ = new float[num_init_score_];
init_score_ = new float[num_init_score_ * num_class_];
double tmp = 0.0f; double tmp = 0.0f;
if (num_class_ == 1){
for (data_size_t i = 0; i < num_init_score_; ++i) { for (data_size_t i = 0; i < num_init_score_; ++i) {
Common::Atof(reader.Lines()[i].c_str(), &tmp); Common::Atof(reader.Lines()[i].c_str(), &tmp);
init_score_[i] = static_cast<float>(tmp); init_score_[i] = static_cast<float>(tmp);
} }
} else {
std::vector<std::string> oneline_init_score;
for (data_size_t i = 0; i < num_init_score_; ++i) {
oneline_init_score = Common::Split(reader.Lines()[i].c_str(), '\t');
if (static_cast<int>(oneline_init_score.size()) != num_class_){
Log::Fatal("Invalid initial score file. Redundant or insufficient columns.");
}
for (int k = 0; k < num_class_; ++k) {
Common::Atof(oneline_init_score[k].c_str(), &tmp);
init_score_[k * num_init_score_ + i] = static_cast<float>(tmp);
}
}
}
} }
void Metadata::LoadQueryBoundaries() { void Metadata::LoadQueryBoundaries() {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment