Commit b953cd58 authored by Guolin Ke's avatar Guolin Ke
Browse files

refine dataset

parent b1b2f181
......@@ -30,9 +30,7 @@ DllExport const char* LGBM_GetLastError();
/*!
* \brief load data set from file, like the command-line LightGBM does
* \param parameters additional parameters:
has_header, label_column, weight_column, group_column, ignore_column
use format like 'has_header=true label_column=1 '..
* \param parameters additional parameters
* \param filename the name of the file
* \param reference used to align bin mapper with other dataset, nullptr means not used
* \param out a loaded dataset
......@@ -60,6 +58,7 @@ DllExport int LGBM_CreateDatasetFromBinaryFile(const char* filename,
* \param nindptr number of rows in the matrix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_col number of columns; when it's set to 0, then guess from data
* \param parameters additional parameters
* \param reference used to align bin mapper with other dataset, nullptr means not used
* \param out created dataset
* \return 0 when success, -1 when failure happens
......@@ -70,6 +69,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const uint64_t* indptr,
uint64_t nindptr,
uint64_t nelem,
uint64_t num_col,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out);
......@@ -81,6 +81,7 @@ DllExport int LGBM_CreateDatasetFromCSR(const uint64_t* indptr,
* \param nindptr number of rows in the matrix + 1
* \param nelem number of nonzero elements in the matrix
* \param num_row number of rows; when it's set to 0, then guess from data
* \param parameters additional parameters
* \param reference used to align bin mapper with other dataset, nullptr means not used
* \param out created dataset
* \return 0 when success, -1 when failure happens
......@@ -91,6 +92,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const uint64_t* col_ptr,
uint64_t nindptr,
uint64_t nelem,
uint64_t num_row,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out);
......@@ -100,6 +102,7 @@ DllExport int LGBM_CreateDatasetFromCSC(const uint64_t* col_ptr,
* \param nrow number of rows
* \param ncol number of columns
* \param missing which value to represent missing value
* \param parameters additional parameters
* \param reference used to align bin mapper with other dataset, nullptr means not used
* \param out created dataset
* \return 0 when success, -1 when failure happens
......@@ -108,6 +111,7 @@ DllExport int LGBM_CreateDatasetFromMat(const float* data,
uint64_t nrow,
uint64_t ncol,
float missing,
const char* parameters,
const DatesetHandle* reference,
DatesetHandle* out);
......
......@@ -99,6 +99,8 @@ public:
bool is_enable_sparse = true;
bool use_two_round_loading = false;
bool is_save_binary_file = false;
bool enable_load_from_binary_file = true;
int bin_construct_sample_cnt = 50000;
bool is_sigmoid = true;
bool has_header = false;
......
......@@ -262,6 +262,13 @@ public:
: Dataset(data_filename, "", io_config, predict_fun) {
}
/*!
* \brief Constructor, without filename, used to load data from memory
* \param io_config configs for IO
* \param predict_fun Used for initial model, will give a prediction score based on this function, then set as initial score
*/
Dataset(const IOConfig& io_config, const PredictFunction& predict_fun);
/*! \brief Destructor */
~Dataset();
......@@ -290,10 +297,19 @@ public:
*/
void LoadValidationData(const Dataset* train_set, bool use_two_round_loading);
/*!
* \brief Load data set from binary file
* \param bin_filename filename of bin data
* \param rank Rank of local machine
* \param num_machines Total number of all machines
* \param is_pre_partition True if data file is pre-partitioned
*/
void LoadDataFromBinFile(const char* bin_filename, int rank, int num_machines, bool is_pre_partition);
/*!
* \brief Save current dataset into binary file, will save to "filename.bin"
*/
void SaveBinaryFile();
void SaveBinaryFile(const char* bin_filename);
/*!
* \brief Get a feature pointer for specific index
......@@ -371,14 +387,6 @@ private:
/*! \brief Check can load from binary file */
void CheckCanLoadFromBin();
/*!
* \brief Load data set from binary file
* \param rank Rank of local machine
* \param num_machines Total number of all machines
* \param is_pre_partition True if data file is pre-partitioned
*/
void LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partition);
/*! \brief Check this data set is null or not */
void CheckDataset();
......@@ -424,6 +432,8 @@ private:
std::unordered_set<int> ignore_features_;
/*! \brief store feature names */
std::vector<std::string> feature_names_;
/*! \brief number of samples used to construct bin mappers */
int bin_construct_sample_cnt_;
};
} // namespace LightGBM
......
......@@ -150,7 +150,7 @@ void Application::LoadData() {
}
// need save binary file
if (config_.io_config.is_save_binary_file) {
train_data_->SaveBinaryFile();
train_data_->SaveBinaryFile(nullptr);
}
// create training metric
if (config_.boosting_config->is_provide_training_metric) {
......@@ -175,7 +175,7 @@ void Application::LoadData() {
config_.io_config.use_two_round_loading);
// need save binary file
if (config_.io_config.is_save_binary_file) {
valid_datas_.back()->SaveBinaryFile();
valid_datas_.back()->SaveBinaryFile(nullptr);
}
// add metric for validation data
......
......@@ -25,7 +25,7 @@ void LoadFileToBoosting(Boosting* boosting, const char* filename) {
}
Boosting* Boosting::CreateBoosting(BoostingType type, const char* filename) {
if (filename[0] == '\0') {
if (filename == nullptr || filename[0] == '\0') {
if (type == BoostingType::kGBDT) {
return new GBDT();
} else {
......
......@@ -191,10 +191,12 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
}
GetInt(params, "verbose", &verbosity);
GetInt(params, "num_model_predict", &num_model_predict);
GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
GetBool(params, "is_pre_partition", &is_pre_partition);
GetBool(params, "is_enable_sparse", &is_enable_sparse);
GetBool(params, "use_two_round_loading", &use_two_round_loading);
GetBool(params, "is_save_binary_file", &is_save_binary_file);
GetBool(params, "enable_load_from_binary_file", &enable_load_from_binary_file);
GetBool(params, "is_sigmoid", &is_sigmoid);
GetString(params, "output_model", &output_model);
GetString(params, "input_model", &input_model);
......
......@@ -18,9 +18,11 @@ namespace LightGBM {
Dataset::Dataset(const char* data_filename, const char* init_score_filename,
const IOConfig& io_config, const PredictFunction& predict_fun)
:data_filename_(data_filename), random_(io_config.data_random_seed),
max_bin_(io_config.max_bin), is_enable_sparse_(io_config.is_enable_sparse), predict_fun_(predict_fun) {
CheckCanLoadFromBin();
max_bin_(io_config.max_bin), is_enable_sparse_(io_config.is_enable_sparse),
predict_fun_(predict_fun), bin_construct_sample_cnt_(io_config.bin_construct_sample_cnt) {
if (io_config.enable_load_from_binary_file) {
CheckCanLoadFromBin();
}
if (is_loading_from_binfile_ && predict_fun != nullptr) {
Log::Info("Cannot performing initialization of prediction by using binary file, using text file instead");
is_loading_from_binfile_ = false;
......@@ -160,6 +162,17 @@ Dataset::Dataset(const char* data_filename, const char* init_score_filename,
}
Dataset::Dataset(const IOConfig& io_config, const PredictFunction& predict_fun)
// no backing data file for this constructor, so data_filename_ is empty
:data_filename_(""), random_(io_config.data_random_seed),
max_bin_(io_config.max_bin), is_enable_sparse_(io_config.is_enable_sparse),
predict_fun_(predict_fun), bin_construct_sample_cnt_(io_config.bin_construct_sample_cnt) {
// data comes from memory, not from a file: no parser or text reader is created
// (the destructor only deletes them when non-null)
parser_ = nullptr;
text_reader_ = nullptr;
}
Dataset::~Dataset() {
if (parser_ != nullptr) { delete parser_; }
if (text_reader_ != nullptr) { delete text_reader_; }
......@@ -216,7 +229,7 @@ void Dataset::LoadDataToMemory(int rank, int num_machines, bool is_pre_partition
}
void Dataset::SampleDataFromMemory(std::vector<std::string>* out_data) {
const size_t sample_cnt = static_cast<size_t>(num_data_ < 50000 ? num_data_ : 50000);
const size_t sample_cnt = static_cast<size_t>(num_data_ < bin_construct_sample_cnt_ ? num_data_ : bin_construct_sample_cnt_);
std::vector<size_t> sample_indices = random_.Sample(num_data_, sample_cnt);
out_data->clear();
for (size_t i = 0; i < sample_indices.size(); ++i) {
......@@ -228,7 +241,7 @@ void Dataset::SampleDataFromMemory(std::vector<std::string>* out_data) {
void Dataset::SampleDataFromFile(int rank, int num_machines, bool is_pre_partition,
std::vector<std::string>* out_data) {
used_data_indices_.clear();
const size_t sample_cnt = 50000;
const data_size_t sample_cnt = static_cast<data_size_t>(bin_construct_sample_cnt_);
if (num_machines == 1 || is_pre_partition) {
num_data_ = static_cast<data_size_t>(text_reader_->SampleFromFile(random_, sample_cnt, out_data));
global_num_data_ = num_data_;
......@@ -452,8 +465,10 @@ void Dataset::LoadTrainData(int rank, int num_machines, bool is_pre_partition, b
ExtractFeaturesFromFile();
}
} else {
std::string bin_filename(data_filename_);
bin_filename.append(".bin");
// load data from binary file
LoadDataFromBinFile(rank, num_machines, is_pre_partition);
LoadDataFromBinFile(bin_filename.c_str(), rank, num_machines, is_pre_partition);
}
// check meta data
metadata_.CheckOrPartition(static_cast<data_size_t>(global_num_data_), used_data_indices_);
......@@ -497,8 +512,10 @@ void Dataset::LoadValidationData(const Dataset* train_set, bool use_two_round_lo
ExtractFeaturesFromFile();
}
} else {
std::string bin_filename(data_filename_);
bin_filename.append(".bin");
// load from binary file
LoadDataFromBinFile(0, 1, false);
LoadDataFromBinFile(bin_filename.c_str(), 0, 1, false);
}
// no need to check validation data
// check meta data
......@@ -646,19 +663,23 @@ void Dataset::ExtractFeaturesFromFile() {
}
}
void Dataset::SaveBinaryFile() {
// if loaded from binary file, no need to save again
void Dataset::SaveBinaryFile(const char* bin_filename) {
if (!is_loading_from_binfile_) {
std::string bin_filename(data_filename_);
bin_filename.append(".bin");
// if no filename is passed, just append ".bin" to the original data filename
if (bin_filename == nullptr || bin_filename[0] == '\0') {
std::string bin_filename_str(data_filename_);
bin_filename_str.append(".bin");
bin_filename = bin_filename_str.c_str();
}
FILE* file;
#ifdef _MSC_VER
fopen_s(&file, bin_filename.c_str(), "wb");
fopen_s(&file, bin_filename, "wb");
#else
file = fopen(bin_filename.c_str(), "wb");
file = fopen(bin_filename, "wb");
#endif
if (file == NULL) {
Log::Fatal("Cannot write binary data to %s ", bin_filename.c_str());
Log::Fatal("Cannot write binary data to %s ", bin_filename);
}
Log::Info("Saving data to binary file: %s", data_filename_);
......@@ -715,20 +736,18 @@ void Dataset::CheckCanLoadFromBin() {
}
}
void Dataset::LoadDataFromBinFile(int rank, int num_machines, bool is_pre_partition) {
std::string bin_filename(data_filename_);
bin_filename.append(".bin");
void Dataset::LoadDataFromBinFile(const char* bin_filename, int rank, int num_machines, bool is_pre_partition) {
FILE* file;
#ifdef _MSC_VER
fopen_s(&file, bin_filename.c_str(), "rb");
fopen_s(&file, bin_filename, "rb");
#else
file = fopen(bin_filename.c_str(), "rb");
file = fopen(bin_filename, "rb");
#endif
if (file == NULL) {
Log::Fatal("Cannot read binary data from %s", bin_filename.c_str());
Log::Fatal("Cannot read binary data from %s", bin_filename);
}
// buffer to read binary file
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment