"git@developer.sourcefind.cn:tianlh/lightgbm-dcu.git" did not exist on "a9d9b1199ce81cf307bbbbe2b4e5dbc76dc048b3"
Commit b953cd58 authored by Guolin Ke's avatar Guolin Ke
Browse files

refine dataset

parent b1b2f181
...@@ -30,9 +30,7 @@ DllExport const char* LGBM_GetLastError(); ...@@ -30,9 +30,7 @@ DllExport const char* LGBM_GetLastError();
/*! /*!
* \brief load data set from file like the command_line LightGBM do * \brief load data set from file like the command_line LightGBM do
* \param parameters additional parameters: * \param parameters additional parameters
has_header, label_column, weight_column, group_column, ignore_column
use format like 'has_header=true label_column=1 '..
* \param filename the name of the file * \param filename the name of the file
* \param reference used to align bin mapper with other dataset, nullptr means don't used * \param reference used to align bin mapper with other dataset, nullptr means don't used
* \param out a loaded dataset * \param out a loaded dataset
...@@ -60,6 +58,7 @@ DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename, ...@@ -60,6 +58,7 @@ DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
* \param nindptr number of rows in the matix + 1 * \param nindptr number of rows in the matix + 1
* \param nelem number of nonzero elements in the matrix * \param nelem number of nonzero elements in the matrix
* \param num_col number of columns; when it's set to 0, then guess from data * \param num_col number of columns; when it's set to 0, then guess from data
* \param parameters additional parameters
* \param reference used to align bin mapper with other dataset, nullptr means don't used * \param reference used to align bin mapper with other dataset, nullptr means don't used
* \param out created dataset * \param out created dataset
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
...@@ -70,6 +69,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const uint64_t* indptr, ...@@ -70,6 +69,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const uint64_t* indptr,
uint64_t nindptr, uint64_t nindptr,
uint64_t nelem, uint64_t nelem,
uint64_t num_col, uint64_t num_col,
const char* parameters,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out); DatesetHandle* out);
...@@ -81,6 +81,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const uint64_t* indptr, ...@@ -81,6 +81,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const uint64_t* indptr,
* \param nindptr number of rows in the matix + 1 * \param nindptr number of rows in the matix + 1
* \param nelem number of nonzero elements in the matrix * \param nelem number of nonzero elements in the matrix
* \param num_row number of rows; when it's set to 0, then guess from data * \param num_row number of rows; when it's set to 0, then guess from data
* \param parameters additional parameters
* \param reference used to align bin mapper with other dataset, nullptr means don't used * \param reference used to align bin mapper with other dataset, nullptr means don't used
* \param out created dataset * \param out created dataset
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
...@@ -91,6 +92,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const uint64_t* col_ptr, ...@@ -91,6 +92,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const uint64_t* col_ptr,
uint64_t nindptr, uint64_t nindptr,
uint64_t nelem, uint64_t nelem,
uint64_t num_row, uint64_t num_row,
const char* parameters,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out); DatesetHandle* out);
...@@ -100,6 +102,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const uint64_t* col_ptr, ...@@ -100,6 +102,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const uint64_t* col_ptr,
* \param nrow number of rows * \param nrow number of rows
* \param ncol number columns * \param ncol number columns
* \param missing which value to represent missing value * \param missing which value to represent missing value
* \param parameters additional parameters
* \param reference used to align bin mapper with other dataset, nullptr means don't used * \param reference used to align bin mapper with other dataset, nullptr means don't used
* \param out created dataset * \param out created dataset
* \return 0 when success, -1 when failure happens * \return 0 when success, -1 when failure happens
...@@ -108,6 +111,7 @@ DllExport int LGBM_CreateDatasetFromMat(const float* data, ...@@ -108,6 +111,7 @@ DllExport int LGBM_CreateDatasetFromMat(const float* data,
uint64_t nrow, uint64_t nrow,
uint64_t ncol, uint64_t ncol,
float missing, float missing,
const char* parameters,
const DatesetHandle* reference, const DatesetHandle* reference,
DatesetHandle* out); DatesetHandle* out);
......
...@@ -99,6 +99,8 @@ public: ...@@ -99,6 +99,8 @@ public:
bool is_enable_sparse = true; bool is_enable_sparse = true;
bool use_two_round_loading = false; bool use_two_round_loading = false;
bool is_save_binary_file = false; bool is_save_binary_file = false;
bool enable_load_from_binary_file = true;
int bin_construct_sample_cnt = 50000;
bool is_sigmoid = true; bool is_sigmoid = true;
bool has_header = false; bool has_header = false;
......
...@@ -262,6 +262,13 @@ public: ...@@ -262,6 +262,13 @@ public:
: Dataset(data_filename, "", io_config, predict_fun) { : Dataset(data_filename, "", io_config, predict_fun) {
} }
/*!
* \brief Constructor, without filename, used to load data from memory
* \param io_config configs for IO
* \param predict_fun Used for initial model, will give a prediction score based on this function, then set as initial score
*/
Dataset(const IOConfig& io_config, const PredictFunction& predict_fun);
/*! \brief Destructor */ /*! \brief Destructor */
~Dataset(); ~Dataset();
...@@ -290,10 +297,19 @@ public: ...@@ -290,10 +297,19 @@ public:
*/ */
void LoadValidationData(const Dataset* train_set, bool use_two_round_loading); void LoadValidationData(const Dataset* train_set, bool use_two_round_loading);
/*!
* \brief Load data set from binary file
* \param bin_filename filename of bin data
* \param rank Rank of local machine
* \param num_machines Total number of all machines
* \param is_pre_partition True if data file is pre-partitioned
*/
void LoadDataFromBinFile(const char* bin_filename, int rank, int num_machines, bool is_pre_partition);
/*! /*!
* \brief Save current dataset into binary file, will save to "filename.bin" * \brief Save current dataset into binary file, will save to "filename.bin"
*/ */
void SaveBinaryFile(); void SaveBinaryFile(const char* bin_filename);
/*! /*!
* \brief Get a feature pointer for specific index * \brief Get a feature pointer for specific index
...@@ -371,14 +387,6 @@ private: ...@@ -371,14 +387,6 @@ private:
/*! \brief Check can load from binary file */ /*! \brief Check can load from binary file */
void CheckCanLoadFromBin(); void CheckCanLoadFromBin();
/*!
* \brief Load data set from binary file
* \param rank Rank of local machine
* \param num_machines Total number of all machines
* \param is_pre_partition True if data file is pre-partitioned
*/
void LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partition);
/*! \brief Check this data set is null or not */ /*! \brief Check this data set is null or not */
void CheckDataset(); void CheckDataset();
...@@ -424,6 +432,8 @@ private: ...@@ -424,6 +432,8 @@ private:
std::unordered_set<int> ignore_features_; std::unordered_set<int> ignore_features_;
/*! \brief store feature names */ /*! \brief store feature names */
std::vector<std::string> feature_names_; std::vector<std::string> feature_names_;
/*! \brief store feature names */
int bin_construct_sample_cnt_;
}; };
} // namespace LightGBM } // namespace LightGBM
......
...@@ -150,7 +150,7 @@ void Application::LoadData() { ...@@ -150,7 +150,7 @@ void Application::LoadData() {
} }
// need save binary file // need save binary file
if (config_.io_config.is_save_binary_file) { if (config_.io_config.is_save_binary_file) {
train_data_->SaveBinaryFile(); train_data_->SaveBinaryFile(nullptr);
} }
// create training metric // create training metric
if (config_.boosting_config->is_provide_training_metric) { if (config_.boosting_config->is_provide_training_metric) {
...@@ -175,7 +175,7 @@ void Application::LoadData() { ...@@ -175,7 +175,7 @@ void Application::LoadData() {
config_.io_config.use_two_round_loading); config_.io_config.use_two_round_loading);
// need save binary file // need save binary file
if (config_.io_config.is_save_binary_file) { if (config_.io_config.is_save_binary_file) {
valid_datas_.back()->SaveBinaryFile(); valid_datas_.back()->SaveBinaryFile(nullptr);
} }
// add metric for validation data // add metric for validation data
......
...@@ -25,7 +25,7 @@ void LoadFileToBoosting(Boosting* boosting, const char* filename) { ...@@ -25,7 +25,7 @@ void LoadFileToBoosting(Boosting* boosting, const char* filename) {
} }
Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) { Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
if (filename[0] == '\0') { if (filename == nullptr || filename[0] == '\0') {
if (type == BoostingType::kGBDT) { if (type == BoostingType::kGBDT) {
return new GBDT(); return new GBDT();
} else { } else {
......
...@@ -191,10 +191,12 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -191,10 +191,12 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
} }
GetInt(params, "verbose", &verbosity); GetInt(params, "verbose", &verbosity);
GetInt(params, "num_model_predict", &num_model_predict); GetInt(params, "num_model_predict", &num_model_predict);
GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
GetBool(params, "is_pre_partition", &is_pre_partition); GetBool(params, "is_pre_partition", &is_pre_partition);
GetBool(params, "is_enable_sparse", &is_enable_sparse); GetBool(params, "is_enable_sparse", &is_enable_sparse);
GetBool(params, "use_two_round_loading", &use_two_round_loading); GetBool(params, "use_two_round_loading", &use_two_round_loading);
GetBool(params, "is_save_binary_file", &is_save_binary_file); GetBool(params, "is_save_binary_file", &is_save_binary_file);
GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file);
GetBool(params, "is_sigmoid", &is_sigmoid); GetBool(params, "is_sigmoid", &is_sigmoid);
GetString(params, "output_model", &output_model); GetString(params, "output_model", &output_model);
GetString(params, "input_model", &input_model); GetString(params, "input_model", &input_model);
......
...@@ -18,9 +18,11 @@ namespace LightGBM { ...@@ -18,9 +18,11 @@ namespace LightGBM {
Dataset::Dataset(const char* data_filename, const char* init_score_filename, Dataset::Dataset(const char* data_filename, const char* init_score_filename,
const IOConfig& io_config, const PredictFunction& predict_fun) const IOConfig& io_config, const PredictFunction& predict_fun)
:data_filename_(data_filename), random_(io_config.data_random_seed), :data_filename_(data_filename), random_(io_config.data_random_seed),
max_bin_(io_config.max_bin), is_enable_sparse_(io_config.is_enable_sparse), predict_fun_(predict_fun) { max_bin_(io_config.max_bin), is_enable_sparse_(io_config.is_enable_sparse),
predict_fun_(predict_fun), bin_construct_sample_cnt_(io_config.bin_construct_sample_cnt) {
CheckCanLoadFromBin(); if (io_config.enable_load_from_binary_file) {
CheckCanLoadFromBin();
}
if (is_loading_from_binfile_ && predict_fun != nullptr) { if (is_loading_from_binfile_ && predict_fun != nullptr) {
Log::Info("Cannot performing initialization of prediction by using binary file, using text file instead"); Log::Info("Cannot performing initialization of prediction by using binary file, using text file instead");
is_loading_from_binfile_ = false; is_loading_from_binfile_ = false;
...@@ -160,6 +162,17 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename, ...@@ -160,6 +162,17 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
} }
Dataset::Dataset(const IOConfig& io_config, const PredictFunction& predict_fun)
:data_filename_(""), random_(io_config.data_random_seed),
max_bin_(io_config.max_bin), is_enable_sparse_(io_config.is_enable_sparse),
predict_fun_(predict_fun), bin_construct_sample_cnt_(io_config.bin_construct_sample_cnt) {
parser_ = nullptr;
text_reader_ = nullptr;
}
Dataset::~Dataset() { Dataset::~Dataset() {
if (parser_ != nullptr) { delete parser_; } if (parser_ != nullptr) { delete parser_; }
if (text_reader_ != nullptr) { delete text_reader_; } if (text_reader_ != nullptr) { delete text_reader_; }
...@@ -216,7 +229,7 @@ void Dataset::LoadDataToMemory(int rank, int num_machines, bool is_pre_partition ...@@ -216,7 +229,7 @@ void Dataset::LoadDataToMemory(int rank, int num_machines, bool is_pre_partition
} }
void Dataset::SampleDataFromMemory(std::vector<std::string>* out_data) { void Dataset::SampleDataFromMemory(std::vector<std::string>* out_data) {
const size_t sample_cnt = static_cast<size_t>(num_data_ < 50000 ? num_data_ : 50000); const size_t sample_cnt = static_cast<size_t>(num_data_ < bin_construct_sample_cnt_ ? num_data_ : bin_construct_sample_cnt_);
std::vector<size_t> sample_indices = random_.Sample(num_data_, sample_cnt); std::vector<size_t> sample_indices = random_.Sample(num_data_, sample_cnt);
out_data->clear(); out_data->clear();
for (size_t i = 0; i < sample_indices.size(); ++i) { for (size_t i = 0; i < sample_indices.size(); ++i) {
...@@ -228,7 +241,7 @@ void Dataset::SampleDataFromMemory(std::vector<std::string>* out_data) { ...@@ -228,7 +241,7 @@ void Dataset::SampleDataFromMemory(std::vector<std::string>* out_data) {
void Dataset::SampleDataFromFile(int rank, int num_machines, bool is_pre_partition, void Dataset::SampleDataFromFile(int rank, int num_machines, bool is_pre_partition,
std::vector<std::string>* out_data) { std::vector<std::string>* out_data) {
used_data_indices_.clear(); used_data_indices_.clear();
const size_t sample_cnt = 50000; const data_size_t sample_cnt = static_cast<data_size_t>(bin_construct_sample_cnt_);
if (num_machines == 1 || is_pre_partition) { if (num_machines == 1 || is_pre_partition) {
num_data_ = static_cast<data_size_t>(text_reader_->SampleFromFile(random_, sample_cnt, out_data)); num_data_ = static_cast<data_size_t>(text_reader_->SampleFromFile(random_, sample_cnt, out_data));
global_num_data_ = num_data_; global_num_data_ = num_data_;
...@@ -452,8 +465,10 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b ...@@ -452,8 +465,10 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b
ExtractFeaturesFromFile(); ExtractFeaturesFromFile();
} }
} else { } else {
std::string bin_filename(data_filename_);
bin_filename.append(".bin");
// load data from binary file // load data from binary file
LoadDataFromBinFile(rank, num_machines, is_pre_partition); LoadDataFromBinFile(bin_filename.c_str(), rank, num_machines, is_pre_partition);
} }
// check meta data // check meta data
metadata_.CheckOrPartition(static_cast<data_size_t>(global_num_data_), used_data_indices_); metadata_.CheckOrPartition(static_cast<data_size_t>(global_num_data_), used_data_indices_);
...@@ -497,8 +512,10 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo ...@@ -497,8 +512,10 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo
ExtractFeaturesFromFile(); ExtractFeaturesFromFile();
} }
} else { } else {
std::string bin_filename(data_filename_);
bin_filename.append(".bin");
// load from binary file // load from binary file
LoadDataFromBinFile(0, 1, false); LoadDataFromBinFile(bin_filename.c_str(), 0, 1, false);
} }
// not need to check validation data // not need to check validation data
// check meta data // check meta data
...@@ -646,19 +663,23 @@ void Dataset::ExtractFeaturesFromFile() { ...@@ -646,19 +663,23 @@ void Dataset::ExtractFeaturesFromFile() {
} }
} }
void Dataset::SaveBinaryFile() { void Dataset::SaveBinaryFile(const char* bin_filename) {
// if is loaded from binary file, not need to save
if (!is_loading_from_binfile_) { if (!is_loading_from_binfile_) {
std::string bin_filename(data_filename_); // if not pass a filename, just append ".bin" of original file
bin_filename.append(".bin"); if (bin_filename == nullptr || bin_filename[0] == '\0') {
std::string bin_filename_str(data_filename_);
bin_filename_str.append(".bin");
bin_filename = bin_filename_str.c_str();
}
FILE* file; FILE* file;
#ifdef _MSC_VER #ifdef _MSC_VER
fopen_s(&file, bin_filename.c_str(), "wb"); fopen_s(&file, bin_filename, "wb");
#else #else
file = fopen(bin_filename.c_str(), "wb"); file = fopen(bin_filename, "wb");
#endif #endif
if (file == NULL) { if (file == NULL) {
Log::Fatal("Cannot write binary data to %s ", bin_filename.c_str()); Log::Fatal("Cannot write binary data to %s ", bin_filename);
} }
Log::Info("Saving data to binary file: %s", data_filename_); Log::Info("Saving data to binary file: %s", data_filename_);
...@@ -715,20 +736,18 @@ void Dataset::CheckCanLoadFromBin() { ...@@ -715,20 +736,18 @@ void Dataset::CheckCanLoadFromBin() {
} }
} }
void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partition) { void Dataset::LoadDataFromBinFile(const char* bin_filename, int rank, int num_machines, bool is_pre_partition) {
std::string bin_filename(data_filename_);
bin_filename.append(".bin");
FILE* file; FILE* file;
#ifdef _MSC_VER #ifdef _MSC_VER
fopen_s(&file, bin_filename.c_str(), "rb"); fopen_s(&file, bin_filename, "rb");
#else #else
file = fopen(bin_filename.c_str(), "rb"); file = fopen(bin_filename, "rb");
#endif #endif
if (file == NULL) { if (file == NULL) {
Log::Fatal("Cannot read binary data from %s", bin_filename.c_str()); Log::Fatal("Cannot read binary data from %s", bin_filename);
} }
// buffer to read binary file // buffer to read binary file
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment