Commit 3a06ce35 authored by Guolin Ke's avatar Guolin Ke
Browse files

clean code

parent a057afec
...@@ -192,8 +192,6 @@ private: ...@@ -192,8 +192,6 @@ private:
int16_t* label_int_; int16_t* label_int_;
/*! \brief Weights data */ /*! \brief Weights data */
float* weights_; float* weights_;
/*! \brief Queries data */
data_size_t* queries_;
/*! \brief Query boundaries */ /*! \brief Query boundaries */
data_size_t* query_boundaries_; data_size_t* query_boundaries_;
/*! \brief Query weights */ /*! \brief Query weights */
...@@ -204,6 +202,8 @@ private: ...@@ -204,6 +202,8 @@ private:
data_size_t num_init_score_; data_size_t num_init_score_;
/*! \brief Initial score */ /*! \brief Initial score */
score_t* init_score_; score_t* init_score_;
/*! \brief Queries data */
data_size_t* queries_;
}; };
......
...@@ -38,12 +38,13 @@ public: ...@@ -38,12 +38,13 @@ public:
char* buffer_process = new char[buffer_size]; char* buffer_process = new char[buffer_size];
// buffer used for the file reading // buffer used for the file reading
char* buffer_read = new char[buffer_size]; char* buffer_read = new char[buffer_size];
size_t read_cnt = 0;
if (skip_bytes > 0) { if (skip_bytes > 0) {
// skip first k bytes // skip first k bytes
fread(buffer_process, 1, skip_bytes, file); read_cnt = fread(buffer_process, 1, skip_bytes, file);
} }
// read first block // read first block
size_t read_cnt = fread(buffer_process, 1, buffer_size, file); read_cnt = fread(buffer_process, 1, buffer_size, file);
size_t last_read_cnt = 0; size_t last_read_cnt = 0;
while (read_cnt > 0) { while (read_cnt > 0) {
// start read thread // start read thread
......
...@@ -34,7 +34,7 @@ public: ...@@ -34,7 +34,7 @@ public:
#else #else
file = fopen(filename, "r"); file = fopen(filename, "r");
#endif #endif
std::stringstream ss; std::stringstream str_buf;
int read_c = -1; int read_c = -1;
read_c = fgetc(file); read_c = fgetc(file);
while (read_c != EOF) { while (read_c != EOF) {
...@@ -42,7 +42,7 @@ public: ...@@ -42,7 +42,7 @@ public:
if (tmp_ch == '\n' || tmp_ch == '\r') { if (tmp_ch == '\n' || tmp_ch == '\r') {
break; break;
} }
ss << tmp_ch; str_buf << tmp_ch;
++skip_bytes_; ++skip_bytes_;
read_c = fgetc(file); read_c = fgetc(file);
} }
...@@ -55,7 +55,7 @@ public: ...@@ -55,7 +55,7 @@ public:
++skip_bytes_; ++skip_bytes_;
} }
fclose(file); fclose(file);
first_line_ = ss.str(); first_line_ = str_buf.str();
Log::Info("skip header:\"%s\" in file %s", first_line_.c_str(), filename_); Log::Info("skip header:\"%s\" in file %s", first_line_.c_str(), filename_);
} }
} }
......
...@@ -275,21 +275,21 @@ void GBDT::Boosting() { ...@@ -275,21 +275,21 @@ void GBDT::Boosting() {
std::string GBDT::ModelsToString() const { std::string GBDT::ModelsToString() const {
// serialize this object to string // serialize this object to string
std::stringstream ss; std::stringstream str_buf;
// output label index // output label index
ss << "label_index=" << label_idx_ << std::endl; str_buf << "label_index=" << label_idx_ << std::endl;
// output max_feature_idx // output max_feature_idx
ss << "max_feature_idx=" << max_feature_idx_ << std::endl; str_buf << "max_feature_idx=" << max_feature_idx_ << std::endl;
// output sigmoid parameter // output sigmoid parameter
ss << "sigmoid=" << object_function_->GetSigmoid() << std::endl; str_buf << "sigmoid=" << object_function_->GetSigmoid() << std::endl;
ss << std::endl; str_buf << std::endl;
// output tree models // output tree models
for (size_t i = 0; i < models_.size(); ++i) { for (size_t i = 0; i < models_.size(); ++i) {
ss << "Tree=" << i << std::endl; str_buf << "Tree=" << i << std::endl;
ss << models_[i]->ToString() << std::endl; str_buf << models_[i]->ToString() << std::endl;
} }
return ss.str(); return str_buf.str();
} }
void GBDT::ModelsFromString(const std::string& model_str, int num_used_model) { void GBDT::ModelsFromString(const std::string& model_str, int num_used_model) {
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <string> #include <string>
#include <sstream>
namespace LightGBM { namespace LightGBM {
...@@ -36,8 +37,8 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -36,8 +37,8 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
if (io_config.has_header) { if (io_config.has_header) {
std::string first_line = text_reader_->first_line(); std::string first_line = text_reader_->first_line();
feature_names_ = Common::Split(first_line.c_str(), "\t ,"); feature_names_ = Common::Split(first_line.c_str(), "\t ,");
for (int i = 0; i < feature_names_.size(); ++i) { for (size_t i = 0; i < feature_names_.size(); ++i) {
name2idx[feature_names_[i]] = i; name2idx[feature_names_[i]] = static_cast<int>(i);
} }
} }
std::string name_prefix("name:"); std::string name_prefix("name:");
...@@ -48,14 +49,25 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -48,14 +49,25 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
std::string name = io_config.label_column.substr(name_prefix.size()); std::string name = io_config.label_column.substr(name_prefix.size());
if (name2idx.count(name) > 0) { if (name2idx.count(name) > 0) {
label_idx_ = name2idx[name]; label_idx_ = name2idx[name];
Log::Info("use %s column as label", name.c_str());
} else { } else {
Log::Fatal("cannot find label column: %s in data file", name.c_str()); Log::Fatal("cannot find label column: %s in data file", name.c_str());
} }
} else { } else {
Common::Atoi(io_config.label_column.c_str(), &label_idx_); size_t pos = 0;
label_idx_ = std::stoi(io_config.label_column, &pos);
if (pos != io_config.label_column.size()) {
Log::Fatal("label_column is not a number, \
if you want to use column name, \
please add prefix \"name:\" before column name");
}
Log::Info("use %d-th column as label", label_idx_);
} }
} }
if (feature_names_.size() > 0) {
// erase label column name
feature_names_.erase(feature_names_.begin() + label_idx_);
}
// load ignore columns // load ignore columns
if (io_config.ignore_column.size() > 0) { if (io_config.ignore_column.size() > 0) {
if (Common::StartsWith(io_config.ignore_column, name_prefix)) { if (Common::StartsWith(io_config.ignore_column, name_prefix)) {
...@@ -72,8 +84,13 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -72,8 +84,13 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
} }
} else { } else {
for (auto token : Common::Split(io_config.ignore_column.c_str(), ',')) { for (auto token : Common::Split(io_config.ignore_column.c_str(), ',')) {
int tmp = 0; size_t pos = 0;
Common::Atoi(token.c_str(), &tmp); int tmp = std::stoi(token, &pos);
if (pos != token.size()) {
Log::Fatal("ignore_column is not a number, \
if you want to use column name, \
please add prefix \"name:\" before column name");
}
// skip for label column // skip for label column
if (tmp > label_idx_) { tmp -= 1; } if (tmp > label_idx_) { tmp -= 1; }
ignore_features_.emplace(tmp); ignore_features_.emplace(tmp);
...@@ -88,11 +105,19 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -88,11 +105,19 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
std::string name = io_config.weight_column.substr(name_prefix.size()); std::string name = io_config.weight_column.substr(name_prefix.size());
if (name2idx.count(name) > 0) { if (name2idx.count(name) > 0) {
weight_idx_ = name2idx[name]; weight_idx_ = name2idx[name];
Log::Info("use %s column as weight", name.c_str());
} else { } else {
Log::Fatal("cannot find weight column: %s in data file", name.c_str()); Log::Fatal("cannot find weight column: %s in data file", name.c_str());
} }
} else { } else {
Common::Atoi(io_config.weight_column.c_str(), &weight_idx_); size_t pos = 0;
weight_idx_ = std::stoi(io_config.weight_column, &pos);
if (pos != io_config.weight_column.size()) {
Log::Fatal("weight_column is not a number, \
if you want to use column name, \
please add prefix \"name:\" before column name");
}
Log::Info("use %d-th column as weight", weight_idx_);
} }
// skip for label column // skip for label column
if (weight_idx_ > label_idx_) { if (weight_idx_ > label_idx_) {
...@@ -106,11 +131,19 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -106,11 +131,19 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
std::string name = io_config.group_column.substr(name_prefix.size()); std::string name = io_config.group_column.substr(name_prefix.size());
if (name2idx.count(name) > 0) { if (name2idx.count(name) > 0) {
group_idx_ = name2idx[name]; group_idx_ = name2idx[name];
Log::Info("use %s column as group/query id", name.c_str());
} else { } else {
Log::Fatal("cannot find group/query column: %s in data file", name.c_str()); Log::Fatal("cannot find group/query column: %s in data file", name.c_str());
} }
} else { } else {
Common::Atoi(io_config.group_column.c_str(), &group_idx_); size_t pos = 0;
group_idx_ = std::stoi(io_config.group_column, &pos);
if (pos != io_config.group_column.size()) {
Log::Fatal("group_column is not a number, \
if you want to use column name, \
please add prefix \"name:\" before column name");
}
Log::Info("use %d-th column as group/query id", group_idx_);
} }
// skip for label column // skip for label column
if (group_idx_ > label_idx_) { if (group_idx_ > label_idx_) {
...@@ -279,6 +312,21 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -279,6 +312,21 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
// -1 means doesn't use this feature // -1 means doesn't use this feature
used_feature_map_ = std::vector<int>(sample_values.size(), -1); used_feature_map_ = std::vector<int>(sample_values.size(), -1);
num_total_features_ = static_cast<int>(sample_values.size()); num_total_features_ = static_cast<int>(sample_values.size());
// check the range of label_idx, weight_idx and group_idx
CHECK(label_idx_ >= 0 && label_idx_ <= num_total_features_);
CHECK(weight_idx_ < 0 || weight_idx_ < num_total_features_);
CHECK(group_idx_ < 0 || group_idx_ < num_total_features_);
// fill feature_names_ if not header
if (feature_names_.size() <= 0) {
for (int i = 0; i < num_total_features_; ++i) {
std::stringstream str_buf;
str_buf << "Column_" << i;
feature_names_.push_back(str_buf.str());
}
}
// start find bins // start find bins
if (num_machines == 1) { if (num_machines == 1) {
std::vector<BinMapper*> bin_mappers(sample_values.size()); std::vector<BinMapper*> bin_mappers(sample_values.size());
...@@ -295,7 +343,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -295,7 +343,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
for (size_t i = 0; i < sample_values.size(); ++i) { for (size_t i = 0; i < sample_values.size(); ++i) {
if (bin_mappers[i] == nullptr) { if (bin_mappers[i] == nullptr) {
Log::Error("Ignore Feature %d ", i); Log::Error("Ignore Feature %s ", feature_names_[i].c_str());
} }
else if (!bin_mappers[i]->is_trival()) { else if (!bin_mappers[i]->is_trival()) {
// map real feature index to used feature index // map real feature index to used feature index
...@@ -305,7 +353,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -305,7 +353,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
num_data_, is_enable_sparse_)); num_data_, is_enable_sparse_));
} else { } else {
// if feature is trivial (only 1 bin), free its space // if feature is trivial (only 1 bin), free its space
Log::Error("Feature %d only contains one value, will be ignored", i); Log::Error("Feature %s only contains one value, will be ignored", feature_names_[i].c_str());
delete bin_mappers[i]; delete bin_mappers[i];
} }
} }
...@@ -353,7 +401,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -353,7 +401,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
// restore features bins from buffer // restore features bins from buffer
for (int i = 0; i < total_num_feature; ++i) { for (int i = 0; i < total_num_feature; ++i) {
if (ignore_features_.count(i) > 0) { if (ignore_features_.count(i) > 0) {
Log::Error("Ignore Feature %d ", i); Log::Error("Ignore Feature %s ", feature_names_[i].c_str());
continue; continue;
} }
BinMapper* bin_mapper = new BinMapper(); BinMapper* bin_mapper = new BinMapper();
...@@ -362,7 +410,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -362,7 +410,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
used_feature_map_[i] = static_cast<int>(features_.size()); used_feature_map_[i] = static_cast<int>(features_.size());
features_.push_back(new Feature(static_cast<int>(i), bin_mapper, num_data_, is_enable_sparse_)); features_.push_back(new Feature(static_cast<int>(i), bin_mapper, num_data_, is_enable_sparse_));
} else { } else {
Log::Error("Feature %d only contains one value, will be ignored", i); Log::Error("Feature %s only contains one value, will be ignored", feature_names_[i].c_str());
delete bin_mapper; delete bin_mapper;
} }
} }
...@@ -377,7 +425,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -377,7 +425,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, bool use_two_round_loading) { void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, bool use_two_round_loading) {
// don't support query id in data file when training parallel // don't support query id in data file when training in parallel
if (num_machines > 1 && !is_pre_partition) { if (num_machines > 1 && !is_pre_partition) {
if (group_idx_ > 0) { if (group_idx_ > 0) {
Log::Fatal("Don't support query id in data file when training parallel without pre-partition. \ Log::Fatal("Don't support query id in data file when training parallel without pre-partition. \
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
#include <iostream> #include <iostream>
#include <fstream> #include <fstream>
#include <functional>
namespace LightGBM { namespace LightGBM {
...@@ -20,37 +21,53 @@ void GetStatistic(const char* str, int* comma_cnt, int* tab_cnt, int* colon_cnt) ...@@ -20,37 +21,53 @@ void GetStatistic(const char* str, int* comma_cnt, int* tab_cnt, int* colon_cnt)
} }
} }
bool CheckHasLabelForLibsvm(std::string& str) { int GetLabelIdxForLibsvm(std::string& str, int num_features, int label_idx) {
if (num_features <= 0) {
return label_idx;
}
str = Common::Trim(str); str = Common::Trim(str);
auto pos_space = str.find_first_of(" \f\n\r\t\v"); auto pos_space = str.find_first_of(" \f\n\r\t\v");
auto pos_colon = str.find_first_of(":"); auto pos_colon = str.find_first_of(":");
if (pos_colon == std::string::npos || pos_colon > pos_space) { if (pos_colon == std::string::npos || pos_colon > pos_space) {
return true; return -1;
} else { } else {
return false; return label_idx;
} }
} }
bool CheckHasLabelForTSV(std::string& str, int num_features) { int GetLabelIdxForTSV(std::string& str, int num_features, int label_idx) {
if (num_features <= 0) {
return label_idx;
}
str = Common::Trim(str); str = Common::Trim(str);
auto tokens = Common::Split(str.c_str(), '\t'); auto tokens = Common::Split(str.c_str(), '\t');
if (static_cast<int>(tokens.size()) == num_features) { if (static_cast<int>(tokens.size()) == num_features) {
return false; return -1;
} else { } else {
return true; return label_idx;
} }
} }
bool CheckHasLabelForCSV(std::string& str, int num_features) { int GetLabelIdxForCSV(std::string& str, int num_features, int label_idx) {
if (num_features <= 0) {
return label_idx;
}
str = Common::Trim(str); str = Common::Trim(str);
auto tokens = Common::Split(str.c_str(), ','); auto tokens = Common::Split(str.c_str(), ',');
if (static_cast<int>(tokens.size()) == num_features) { if (static_cast<int>(tokens.size()) == num_features) {
return false; return -1;
} else { } else {
return true; return label_idx;
} }
} }
enum DataType {
INVALID,
CSV,
TSV,
LIBSVM
};
Parser* Parser::CreateParser(const char* filename, bool has_header, int num_features, int label_idx) { Parser* Parser::CreateParser(const char* filename, bool has_header, int num_features, int label_idx) {
std::ifstream tmp_file; std::ifstream tmp_file;
tmp_file.open(filename); tmp_file.open(filename);
...@@ -80,46 +97,46 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat ...@@ -80,46 +97,46 @@ Parser* Parser::CreateParser(const char* filename, bool has_header, int num_feat
// Get some statistics from the 2 lines // Get some statistics from the 2 lines
GetStatistic(line1.c_str(), &comma_cnt, &tab_cnt, &colon_cnt); GetStatistic(line1.c_str(), &comma_cnt, &tab_cnt, &colon_cnt);
GetStatistic(line2.c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2); GetStatistic(line2.c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2);
Parser* ret = nullptr;
bool has_label = true;
DataType type = DataType::INVALID;
if (line2.size() == 0) { if (line2.size() == 0) {
// if the file has only one line // if the file has only one line
if (colon_cnt > 0) { if (colon_cnt > 0) {
if (num_features > 0) { type = DataType::LIBSVM;
has_label = CheckHasLabelForLibsvm(line1);
}
ret = new LibSVMParser(has_label ? label_idx : -1);
} else if (tab_cnt > 0) { } else if (tab_cnt > 0) {
if (num_features > 0 ) { type = DataType::TSV;
has_label = CheckHasLabelForTSV(line1, num_features);
}
ret = new TSVParser(has_label ? label_idx : -1);
} else if (comma_cnt > 0) { } else if (comma_cnt > 0) {
if (num_features > 0) { type = DataType::CSV;
has_label = CheckHasLabelForCSV(line1, num_features); }
}
ret = new CSVParser(has_label ? label_idx : -1);
}
} else { } else {
if (colon_cnt > 0 || colon_cnt2 > 0) { if (colon_cnt > 0 || colon_cnt2 > 0) {
if (num_features > 0) { type = DataType::LIBSVM;
has_label = CheckHasLabelForLibsvm(line1); } else if (tab_cnt == tab_cnt2 && tab_cnt > 0) {
} type = DataType::TSV;
ret = new LibSVMParser(has_label ? label_idx : -1);
}
else if (tab_cnt == tab_cnt2 && tab_cnt > 0) {
if (num_features > 0) {
has_label = CheckHasLabelForTSV(line1, num_features);
}
ret = new TSVParser(has_label ? label_idx : -1);
} else if (comma_cnt == comma_cnt2 && comma_cnt > 0) { } else if (comma_cnt == comma_cnt2 && comma_cnt > 0) {
if (num_features > 0) { type = DataType::CSV;
has_label = CheckHasLabelForCSV(line1, num_features);
}
ret = new CSVParser(has_label ? label_idx : -1);
} }
} }
if (!has_label) { if (type == DataType::INVALID) {
Log::Fatal("Unkown format of training data");
}
Parser* ret = nullptr;
if (type == DataType::LIBSVM) {
label_idx = GetLabelIdxForLibsvm(line1, num_features, label_idx);
ret = new LibSVMParser(label_idx);
}
else if (type == DataType::TSV) {
label_idx = GetLabelIdxForTSV(line1, num_features, label_idx);
ret = new TSVParser(label_idx);
}
else if (type == DataType::CSV) {
label_idx = GetLabelIdxForCSV(line1, num_features, label_idx);
ret = new CSVParser(label_idx);
}
if (label_idx < 0) {
Log::Info("Data file: %s doesn't contain label column", filename); Log::Info("Data file: %s doesn't contain label column", filename);
} }
return ret; return ret;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment