Commit 3b161919 authored by Guolin Ke's avatar Guolin Ke
Browse files

Support specifying an explicit path for initial score files.

parent 850f0391
...@@ -207,7 +207,11 @@ The parameter format is `key1=value1 key2=value2 ... ` . And parameters can be s ...@@ -207,7 +207,11 @@ The parameter format is `key1=value1 key2=value2 ... ` . And parameters can be s
* `zero_as_missing`, default=`false`, type=bool * `zero_as_missing`, default=`false`, type=bool
* Set to `true` will treat all zero as missing values (including the unshown values in libsvm/sparse matrices). * Set to `true` will treat all zero as missing values (including the unshown values in libsvm/sparse matrices).
* Set to `false` will use `na` to represent missing values. * Set to `false` will use `na` to represent missing values.
* `init_score_file`, default=`""`, type=string
* Path of the training initial score file; if `""`, `train_data_file+".init"` will be used (if it exists).
* `valid_init_score_file`, default=`""`, type=multi-string
* Path of the validation initial score file; if `""`, `valid_data_file+".init"` will be used (if it exists).
* Separate paths with `,` when there are multiple validation datasets.
## Objective parameters ## Objective parameters
......
...@@ -95,7 +95,9 @@ public: ...@@ -95,7 +95,9 @@ public:
int num_class = 1; int num_class = 1;
int data_random_seed = 1; int data_random_seed = 1;
std::string data_filename = ""; std::string data_filename = "";
std::string initscore_filename = "";
std::vector<std::string> valid_data_filenames; std::vector<std::string> valid_data_filenames;
std::vector<std::string> valid_data_initscores;
int snapshot_freq = -1; int snapshot_freq = -1;
std::string output_model = "LightGBM_model.txt"; std::string output_model = "LightGBM_model.txt";
std::string output_result = "LightGBM_predict_result.txt"; std::string output_result = "LightGBM_predict_result.txt";
...@@ -455,7 +457,8 @@ struct ParameterAlias { ...@@ -455,7 +457,8 @@ struct ParameterAlias {
"feature_fraction_seed", "enable_bundle", "data_filename", "valid_data_filenames", "feature_fraction_seed", "enable_bundle", "data_filename", "valid_data_filenames",
"snapshot_freq", "verbosity", "sparse_threshold", "enable_load_from_binary_file", "snapshot_freq", "verbosity", "sparse_threshold", "enable_load_from_binary_file",
"max_conflict_rate", "poisson_max_delta_step", "gaussian_eta", "max_conflict_rate", "poisson_max_delta_step", "gaussian_eta",
"histogram_pool_size", "output_freq", "is_provide_training_metric", "machine_list_filename", "zero_as_missing" "histogram_pool_size", "output_freq", "is_provide_training_metric", "machine_list_filename", "zero_as_missing",
"init_score_file", "valid_init_score_file"
}); });
std::unordered_map<std::string, std::string> tmp_map; std::unordered_map<std::string, std::string> tmp_map;
for (const auto& pair : *params) { for (const auto& pair : *params) {
......
...@@ -44,7 +44,7 @@ public: ...@@ -44,7 +44,7 @@ public:
* \param data_filename Filename of data * \param data_filename Filename of data
* \param initscore_file Filename of initial score * \param initscore_file Filename of initial score
*/ */
void Init(const char* data_filename); void Init(const char* data_filename, const char* initscore_file);
/*! /*!
* \brief init as subset * \brief init as subset
* \param metadata Filename of data * \param metadata Filename of data
...@@ -211,7 +211,7 @@ public: ...@@ -211,7 +211,7 @@ public:
private: private:
/*! \brief Load initial scores from file */ /*! \brief Load initial scores from file */
void LoadInitialScore(); void LoadInitialScore(const char* initscore_file);
/*! \brief Load weights from file */ /*! \brief Load weights from file */
void LoadWeights(); void LoadWeights();
/*! \brief Load query boundaries from file */ /*! \brief Load query boundaries from file */
......
...@@ -12,13 +12,13 @@ public: ...@@ -12,13 +12,13 @@ public:
LIGHTGBM_EXPORT ~DatasetLoader(); LIGHTGBM_EXPORT ~DatasetLoader();
LIGHTGBM_EXPORT Dataset* LoadFromFile(const char* filename, int rank, int num_machines); LIGHTGBM_EXPORT Dataset* LoadFromFile(const char* filename, const char* initscore_file, int rank, int num_machines);
LIGHTGBM_EXPORT Dataset* LoadFromFile(const char* filename) { LIGHTGBM_EXPORT Dataset* LoadFromFile(const char* filename, const char* initscore_file) {
return LoadFromFile(filename, 0, 1); return LoadFromFile(filename, initscore_file, 0, 1);
} }
LIGHTGBM_EXPORT Dataset* LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data); LIGHTGBM_EXPORT Dataset* LoadFromFileAlignWithOtherDataset(const char* filename, const char* initscore_file, const Dataset* train_data);
LIGHTGBM_EXPORT Dataset* CostructFromSampleData(double** sample_values, LIGHTGBM_EXPORT Dataset* CostructFromSampleData(double** sample_values,
int** sample_indices, int num_col, const int* num_per_col, int** sample_indices, int num_col, const int* num_per_col,
......
...@@ -126,10 +126,12 @@ void Application::LoadData() { ...@@ -126,10 +126,12 @@ void Application::LoadData() {
if (config_.is_parallel_find_bin) { if (config_.is_parallel_find_bin) {
// load data for parallel training // load data for parallel training
train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(),
config_.io_config.initscore_filename.c_str(),
Network::rank(), Network::num_machines())); Network::rank(), Network::num_machines()));
} else { } else {
// load data for single machine // load data for single machine
train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), 0, 1)); train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), config_.io_config.initscore_filename.c_str(),
0, 1));
} }
// need save binary file // need save binary file
if (config_.io_config.is_save_binary_file) { if (config_.io_config.is_save_binary_file) {
...@@ -156,6 +158,7 @@ void Application::LoadData() { ...@@ -156,6 +158,7 @@ void Application::LoadData() {
auto new_dataset = std::unique_ptr<Dataset>( auto new_dataset = std::unique_ptr<Dataset>(
dataset_loader.LoadFromFileAlignWithOtherDataset( dataset_loader.LoadFromFileAlignWithOtherDataset(
config_.io_config.valid_data_filenames[i].c_str(), config_.io_config.valid_data_filenames[i].c_str(),
config_.io_config.valid_data_initscores[i].c_str(),
train_data_.get()) train_data_.get())
); );
valid_datas_.push_back(std::move(new_dataset)); valid_datas_.push_back(std::move(new_dataset));
......
...@@ -350,11 +350,11 @@ int LGBM_DatasetCreateFromFile(const char* filename, ...@@ -350,11 +350,11 @@ int LGBM_DatasetCreateFromFile(const char* filename,
if (config.num_threads > 0) { if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads); omp_set_num_threads(config.num_threads);
} }
DatasetLoader loader(config.io_config, nullptr, 1, filename); DatasetLoader loader(config.io_config, nullptr, 1, filename);
if (reference == nullptr) { if (reference == nullptr) {
*out = loader.LoadFromFile(filename); *out = loader.LoadFromFile(filename, "");
} else { } else {
*out = loader.LoadFromFileAlignWithOtherDataset(filename, *out = loader.LoadFromFileAlignWithOtherDataset(filename, "",
reinterpret_cast<const Dataset*>(reference)); reinterpret_cast<const Dataset*>(reference));
} }
API_END(); API_END();
......
...@@ -245,6 +245,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -245,6 +245,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "num_class", &num_class); GetInt(params, "num_class", &num_class);
GetInt(params, "data_random_seed", &data_random_seed); GetInt(params, "data_random_seed", &data_random_seed);
GetString(params, "data", &data_filename); GetString(params, "data", &data_filename);
GetString(params, "init_score_file", &initscore_filename);
GetInt(params, "verbose", &verbosity); GetInt(params, "verbose", &verbosity);
GetInt(params, "num_iteration_predict", &num_iteration_predict); GetInt(params, "num_iteration_predict", &num_iteration_predict);
GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt); GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
...@@ -265,6 +266,12 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) { ...@@ -265,6 +266,12 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
if (GetString(params, "valid_data", &tmp_str)) { if (GetString(params, "valid_data", &tmp_str)) {
valid_data_filenames = Common::Split(tmp_str.c_str(), ','); valid_data_filenames = Common::Split(tmp_str.c_str(), ',');
} }
if (GetString(params, "valid_init_score_file", &tmp_str)) {
valid_data_initscores = Common::Split(tmp_str.c_str(), ',');
} else {
valid_data_initscores = std::vector<std::string>(valid_data_filenames.size(), "");
}
CHECK(valid_data_filenames.size() == valid_data_initscores.size());
GetBool(params, "has_header", &has_header); GetBool(params, "has_header", &has_header);
GetString(params, "label_column", &label_column); GetString(params, "label_column", &label_column);
GetString(params, "weight_column", &weight_column); GetString(params, "weight_column", &weight_column);
......
...@@ -156,7 +156,7 @@ void DatasetLoader::SetHeader(const char* filename) { ...@@ -156,7 +156,7 @@ void DatasetLoader::SetHeader(const char* filename) {
} }
} }
Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_machines) { Dataset* DatasetLoader::LoadFromFile(const char* filename, const char* initscore_file, int rank, int num_machines) {
// don't support query id in data file when training in parallel // don't support query id in data file when training in parallel
if (num_machines > 1 && !io_config_.is_pre_partition) { if (num_machines > 1 && !io_config_.is_pre_partition) {
if (group_idx_ > 0) { if (group_idx_ > 0) {
...@@ -175,7 +175,7 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac ...@@ -175,7 +175,7 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
} }
dataset->data_filename_ = filename; dataset->data_filename_ = filename;
dataset->label_idx_ = label_idx_; dataset->label_idx_ = label_idx_;
dataset->metadata_.Init(filename); dataset->metadata_.Init(filename, initscore_file);
if (!io_config_.use_two_round_loading) { if (!io_config_.use_two_round_loading) {
// read data to memory // read data to memory
auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices); auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
...@@ -218,7 +218,7 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac ...@@ -218,7 +218,7 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) { Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const char* initscore_file, const Dataset* train_data) {
data_size_t num_global_data = 0; data_size_t num_global_data = 0;
std::vector<data_size_t> used_data_indices; std::vector<data_size_t> used_data_indices;
auto dataset = std::unique_ptr<Dataset>(new Dataset()); auto dataset = std::unique_ptr<Dataset>(new Dataset());
...@@ -230,7 +230,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, ...@@ -230,7 +230,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
} }
dataset->data_filename_ = filename; dataset->data_filename_ = filename;
dataset->label_idx_ = label_idx_; dataset->label_idx_ = label_idx_;
dataset->metadata_.Init(filename); dataset->metadata_.Init(filename, initscore_file);
if (!io_config_.use_two_round_loading) { if (!io_config_.use_two_round_loading) {
// read data in memory // read data in memory
auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices); auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices);
......
...@@ -17,13 +17,13 @@ Metadata::Metadata() { ...@@ -17,13 +17,13 @@ Metadata::Metadata() {
init_score_load_from_file_ = false; init_score_load_from_file_ = false;
} }
void Metadata::Init(const char * data_filename) { void Metadata::Init(const char * data_filename, const char* initscore_file) {
data_filename_ = data_filename; data_filename_ = data_filename;
// for lambdarank, it needs query data for partition data in parallel learning // for lambdarank, it needs query data for partition data in parallel learning
LoadQueryBoundaries(); LoadQueryBoundaries();
LoadWeights(); LoadWeights();
LoadQueryWeights(); LoadQueryWeights();
LoadInitialScore(); LoadInitialScore(initscore_file);
} }
Metadata::~Metadata() { Metadata::~Metadata() {
...@@ -389,11 +389,14 @@ void Metadata::LoadWeights() { ...@@ -389,11 +389,14 @@ void Metadata::LoadWeights() {
weight_load_from_file_ = true; weight_load_from_file_ = true;
} }
void Metadata::LoadInitialScore() { void Metadata::LoadInitialScore(const char* initscore_file) {
num_init_score_ = 0; num_init_score_ = 0;
std::string init_score_filename(data_filename_); std::string init_score_filename(initscore_file);
// default weight file name if (init_score_filename.size() <= 0) {
init_score_filename.append(".init"); init_score_filename = std::string(data_filename_);
// default weight file name
init_score_filename.append(".init");
}
TextReader<size_t> reader(init_score_filename.c_str(), false); TextReader<size_t> reader(init_score_filename.c_str(), false);
reader.ReadAllLines(); reader.ReadAllLines();
if (reader.Lines().empty()) { if (reader.Lines().empty()) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment