"src/git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "28481763357e83ef67ecc0d650c3ad94fe7c8776"
Commit aee30126 authored by Guolin Ke's avatar Guolin Ke
Browse files

No need for the data_has_label parameter any more.

parent 728875b4
...@@ -4,5 +4,3 @@ task = predict ...@@ -4,5 +4,3 @@ task = predict
data = binary.test data = binary.test
input_model= LightGBM_model.txt input_model= LightGBM_model.txt
data_has_label = true
...@@ -4,5 +4,3 @@ task = predict ...@@ -4,5 +4,3 @@ task = predict
data = rank.test data = rank.test
input_model= LightGBM_model.txt input_model= LightGBM_model.txt
data_has_label = true
...@@ -5,4 +5,3 @@ data = binary.test ...@@ -5,4 +5,3 @@ data = binary.test
input_model= LightGBM_model.txt input_model= LightGBM_model.txt
data_has_label = true
...@@ -4,5 +4,3 @@ task = predict ...@@ -4,5 +4,3 @@ task = predict
data = regression.test data = regression.test
input_model= LightGBM_model.txt input_model= LightGBM_model.txt
data_has_label = true
...@@ -88,7 +88,6 @@ public: ...@@ -88,7 +88,6 @@ public:
int max_bin = 255; int max_bin = 255;
int data_random_seed = 1; int data_random_seed = 1;
std::string data_filename = ""; std::string data_filename = "";
bool data_has_label = true;
std::vector<std::string> valid_data_filenames; std::vector<std::string> valid_data_filenames;
std::string output_model = "LightGBM_model.txt"; std::string output_model = "LightGBM_model.txt";
std::string output_result = "LightGBM_predict_result.txt"; std::string output_result = "LightGBM_predict_result.txt";
...@@ -274,8 +273,6 @@ struct ParameterAlias { ...@@ -274,8 +273,6 @@ struct ParameterAlias {
{ "app", "objective" }, { "app", "objective" },
{ "train_data", "data" }, { "train_data", "data" },
{ "train", "data" }, { "train", "data" },
{ "has_label", "data_has_label" },
{ "is_data_has_label", "data_has_label" },
{ "model_output", "output_model" }, { "model_output", "output_model" },
{ "model_out", "output_model" }, { "model_out", "output_model" },
{ "model_input", "input_model" }, { "model_input", "input_model" },
......
...@@ -208,9 +208,11 @@ public: ...@@ -208,9 +208,11 @@ public:
/*! /*!
* \brief Create a object of parser, will auto choose the format depend on file * \brief Create a object of parser, will auto choose the format depend on file
* \param filename One Filename of data * \param filename One Filename of data
* \param num_features Pass num_features of this data file if you know, <=0 means don't know
* \param has_label output, if num_features > 0, will output this data has label or not
* \return Object of parser * \return Object of parser
*/ */
static Parser* CreateParser(const char* filename); static Parser* CreateParser(const char* filename, int num_features, bool* has_label);
}; };
using PredictFunction = using PredictFunction =
...@@ -299,6 +301,9 @@ public: ...@@ -299,6 +301,9 @@ public:
/*! \brief Get Number of used features */ /*! \brief Get Number of used features */
inline int num_features() const { return num_features_; } inline int num_features() const { return num_features_; }
/*! \brief Get Number of total features */
inline int num_total_features() const { return num_total_features_; }
/*! \brief Get Number of data */ /*! \brief Get Number of data */
inline data_size_t num_data() const { return num_data_; } inline data_size_t num_data() const { return num_data_; }
...@@ -373,6 +378,8 @@ private: ...@@ -373,6 +378,8 @@ private:
std::vector<int> used_feature_map_; std::vector<int> used_feature_map_;
/*! \brief Number of used features*/ /*! \brief Number of used features*/
int num_features_; int num_features_;
/*! \brief Number of total features*/
int num_total_features_;
/*! \brief Number of total data*/ /*! \brief Number of total data*/
data_size_t num_data_; data_size_t num_data_;
/*! \brief Store some label level data*/ /*! \brief Store some label level data*/
......
...@@ -253,8 +253,7 @@ void Application::Train() { ...@@ -253,8 +253,7 @@ void Application::Train() {
void Application::Predict() { void Application::Predict() {
// create predictor // create predictor
Predictor predictor(boosting_, config_.io_config.is_sigmoid); Predictor predictor(boosting_, config_.io_config.is_sigmoid);
predictor.Predict(config_.io_config.data_filename.c_str(), predictor.Predict(config_.io_config.data_filename.c_str(), config_.io_config.output_result.c_str());
config_.io_config.data_has_label, config_.io_config.output_result.c_str());
Log::Stdout("finish predict"); Log::Stdout("finish predict");
} }
......
...@@ -96,7 +96,7 @@ public: ...@@ -96,7 +96,7 @@ public:
* \param has_label True if this data contains label * \param has_label True if this data contains label
* \param result_filename Filename of output result * \param result_filename Filename of output result
*/ */
void Predict(const char* data_filename, bool has_label, const char* result_filename) { void Predict(const char* data_filename, const char* result_filename) {
FILE* result_file; FILE* result_file;
#ifdef _MSC_VER #ifdef _MSC_VER
...@@ -108,8 +108,8 @@ public: ...@@ -108,8 +108,8 @@ public:
if (result_file == NULL) { if (result_file == NULL) {
Log::Stderr("predition result file %s doesn't exists", data_filename); Log::Stderr("predition result file %s doesn't exists", data_filename);
} }
bool has_label = false;
Parser* parser = Parser::CreateParser(data_filename); Parser* parser = Parser::CreateParser(data_filename, num_features_, &has_label);
if (parser == nullptr) { if (parser == nullptr) {
Log::Stderr("can regonise input data format, filename %s", data_filename); Log::Stderr("can regonise input data format, filename %s", data_filename);
......
...@@ -60,10 +60,7 @@ void GBDT::Init(const Dataset* train_data, const ObjectiveFunction* object_funct ...@@ -60,10 +60,7 @@ void GBDT::Init(const Dataset* train_data, const ObjectiveFunction* object_funct
hessians_ = new score_t[num_data_]; hessians_ = new score_t[num_data_];
// get max feature index // get max feature index
for (int i = 0; i < train_data->num_features(); ++i) { max_feature_idx_ = train_data_->num_total_features() - 1;
max_feature_idx_ = Common::Max<int>(max_feature_idx_,
train_data->FeatureAt(i)->feature_index());
}
// if need bagging, create buffer // if need bagging, create buffer
if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) { if (gbdt_config_->bagging_fraction < 1.0 && gbdt_config_->bagging_freq > 0) {
......
...@@ -15,11 +15,6 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para ...@@ -15,11 +15,6 @@ void OverallConfig::Set(const std::unordered_map<std::string, std::string>& para
GetInt(params, "num_threads", &num_threads); GetInt(params, "num_threads", &num_threads);
GetTaskType(params); GetTaskType(params);
// prediction task, default not has label
if (task_type == TaskType::kPredict) {
io_config.data_has_label = false;
}
GetBoostingType(params); GetBoostingType(params);
GetObjectiveType(params); GetObjectiveType(params);
GetMetricType(params); GetMetricType(params);
...@@ -125,11 +120,6 @@ void OverallConfig::CheckParamConflict() { ...@@ -125,11 +120,6 @@ void OverallConfig::CheckParamConflict() {
TreeLearnerType::kDataParallelTreeLearner) { TreeLearnerType::kDataParallelTreeLearner) {
is_parallel_find_bin = true; is_parallel_find_bin = true;
} }
if (task_type == TaskType::kTrain && io_config.data_has_label == false) {
Log::Stderr("Data should have label in training task");
}
} }
void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
...@@ -141,7 +131,6 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -141,7 +131,6 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
Log::Stderr("No training/prediction data, application quit"); Log::Stderr("No training/prediction data, application quit");
} }
GetInt(params, "num_model_predict", &num_model_predict); GetInt(params, "num_model_predict", &num_model_predict);
GetBool(params, "data_has_label", &data_has_label);
GetBool(params, "is_pre_partition", &is_pre_partition); GetBool(params, "is_pre_partition", &is_pre_partition);
GetBool(params, "is_enable_sparse", &is_enable_sparse); GetBool(params, "is_enable_sparse", &is_enable_sparse);
GetBool(params, "use_two_round_loading", &use_two_round_loading); GetBool(params, "use_two_round_loading", &use_two_round_loading);
......
...@@ -29,7 +29,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -29,7 +29,7 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
// load weight, query information and initilize score // load weight, query information and initilize score
metadata_.Init(data_filename, init_score_filename); metadata_.Init(data_filename, init_score_filename);
// create text parser // create text parser
parser_ = Parser::CreateParser(data_filename_); parser_ = Parser::CreateParser(data_filename_, 0, nullptr);
if (parser_ == nullptr) { if (parser_ == nullptr) {
Log::Stderr("cannot recognise input data format, filename: %s", data_filename_); Log::Stderr("cannot recognise input data format, filename: %s", data_filename_);
} }
...@@ -189,7 +189,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -189,7 +189,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
// -1 means doesn't use this feature // -1 means doesn't use this feature
used_feature_map_ = std::vector<int>(sample_values.size(), -1); used_feature_map_ = std::vector<int>(sample_values.size(), -1);
num_total_features_ = sample_values.size();
// start find bins // start find bins
if (num_machines == 1) { if (num_machines == 1) {
std::vector<BinMapper*> bin_mappers(sample_values.size()); std::vector<BinMapper*> bin_mappers(sample_values.size());
...@@ -209,6 +209,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector< ...@@ -209,6 +209,7 @@ void Dataset::ConstructBinMappers(int rank, int num_machines, const std::vector<
num_data_, is_enable_sparse_)); num_data_, is_enable_sparse_));
} else { } else {
// if feature is trival(only 1 bin), free spaces // if feature is trival(only 1 bin), free spaces
Log::Stdout("Warning: feture %d only contains one value, will ignore it", i);
delete bin_mappers[i]; delete bin_mappers[i];
} }
} }
......
...@@ -20,7 +20,38 @@ void GetStatistic(const char* str, int* comma_cnt, int* tab_cnt, int* colon_cnt) ...@@ -20,7 +20,38 @@ void GetStatistic(const char* str, int* comma_cnt, int* tab_cnt, int* colon_cnt)
} }
} }
Parser* Parser::CreateParser(const char* filename) { bool CheckHasLabelForLibsvm(std::string& str) {
str = Common::Trim(str);
auto pos_space = str.find_first_of(" \f\n\r\t\v");
auto pos_colon = str.find_first_of(":");
if (pos_colon == std::string::npos || pos_colon > pos_space) {
return true;
} else {
return false;
}
}
/*!
* \brief Check whether the first line of a TSV format file contains a label.
*        If the number of tab-separated tokens equals the known feature
*        count, there is no extra leading column, so no label.
* \param str First line of the data file; trimmed in place as a side effect
* \param num_features Known number of features; caller should pass > 0
* \return true if this line appears to contain a label
*/
bool CheckHasLabelForTSV(std::string& str, int num_features) {
  str = Common::Trim(str);
  auto tokens = Common::Split(str.c_str(), '\t');
  // explicit cast avoids the signed/unsigned comparison between
  // std::vector::size_type and int
  return static_cast<int>(tokens.size()) != num_features;
}
/*!
* \brief Check whether the first line of a CSV format file contains a label.
*        If the number of comma-separated tokens equals the known feature
*        count, there is no extra leading column, so no label.
* \param str First line of the data file; trimmed in place as a side effect
* \param num_features Known number of features; caller should pass > 0
* \return true if this line appears to contain a label
*/
bool CheckHasLabelForCSV(std::string& str, int num_features) {
  str = Common::Trim(str);
  auto tokens = Common::Split(str.c_str(), ',');
  // explicit cast avoids the signed/unsigned comparison between
  // std::vector::size_type and int
  return static_cast<int>(tokens.size()) != num_features;
}
Parser* Parser::CreateParser(const char* filename, int num_features, bool* has_label) {
std::ifstream tmp_file; std::ifstream tmp_file;
tmp_file.open(filename); tmp_file.open(filename);
if (!tmp_file.is_open()) { if (!tmp_file.is_open()) {
...@@ -44,29 +75,45 @@ Parser* Parser::CreateParser(const char* filename) { ...@@ -44,29 +75,45 @@ Parser* Parser::CreateParser(const char* filename) {
// Get some statistic from 2 line // Get some statistic from 2 line
GetStatistic(line1.c_str(), &comma_cnt, &tab_cnt, &colon_cnt); GetStatistic(line1.c_str(), &comma_cnt, &tab_cnt, &colon_cnt);
GetStatistic(line2.c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2); GetStatistic(line2.c_str(), &comma_cnt2, &tab_cnt2, &colon_cnt2);
Parser* ret = nullptr;
if (line2.size() == 0) { if (line2.size() == 0) {
// if only have one line on file // if only have one line on file
if (colon_cnt > 0) { if (colon_cnt > 0) {
return new LibSVMParser(); ret = new LibSVMParser();
if (num_features > 0 && has_label != nullptr) {
*has_label = CheckHasLabelForLibsvm(line1);
}
} else if (tab_cnt > 0) { } else if (tab_cnt > 0) {
return new TSVParser(); ret = new TSVParser();
if (num_features > 0 && has_label != nullptr) {
*has_label = CheckHasLabelForTSV(line1, num_features);
}
} else if (comma_cnt > 0) { } else if (comma_cnt > 0) {
return new CSVParser(); ret = new CSVParser();
} else { if (num_features > 0 && has_label != nullptr) {
return nullptr; *has_label = CheckHasLabelForCSV(line1, num_features);
}
} }
} else { } else {
if (colon_cnt > 0 || colon_cnt2 > 0) { if (colon_cnt > 0 || colon_cnt2 > 0) {
return new LibSVMParser(); ret = new LibSVMParser();
if (num_features > 0 && has_label != nullptr) {
*has_label = CheckHasLabelForLibsvm(line1);
}
} }
else if (tab_cnt == tab_cnt2 && tab_cnt > 0) { else if (tab_cnt == tab_cnt2 && tab_cnt > 0) {
return new TSVParser(); ret = new TSVParser();
if (num_features > 0 && has_label != nullptr) {
*has_label = CheckHasLabelForTSV(line1, num_features);
}
} else if (comma_cnt == comma_cnt2 && comma_cnt > 0) { } else if (comma_cnt == comma_cnt2 && comma_cnt > 0) {
return new CSVParser(); ret = new CSVParser();
} else { if (num_features > 0 && has_label != nullptr) {
return nullptr; *has_label = CheckHasLabelForCSV(line1, num_features);
}
} }
} }
return ret;
} }
} // namespace LightGBM } // namespace LightGBM
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment