Commit 3b161919 authored by Guolin Ke's avatar Guolin Ke
Browse files

support specific path of initial scores.

parent 850f0391
......@@ -207,7 +207,11 @@ The parameter format is `key1=value1 key2=value2 ... ` . And parameters can be s
* `zero_as_missing`, default=`false`, type=bool
* Set to `true` will treat all zero as missing values (including the unshown values in libsvm/sparse matrices).
* Set to `false` will use `na` to represent missing values.
* `init_score_file`, default=`""`, type=string
* Path of the training initial score file; `""` will use `train_data_file+".init"` (if it exists).
* `valid_init_score_file`, default=`""`, type=multi-string
* Path of the validation initial score file(s); `""` will use `valid_data_file+".init"` (if it exists).
* Separate with `,` when using multiple validation data sets.
## Objective parameters
......
......@@ -95,7 +95,9 @@ public:
int num_class = 1;
int data_random_seed = 1;
std::string data_filename = "";
std::string initscore_filename = "";
std::vector<std::string> valid_data_filenames;
std::vector<std::string> valid_data_initscores;
int snapshot_freq = -1;
std::string output_model = "LightGBM_model.txt";
std::string output_result = "LightGBM_predict_result.txt";
......@@ -455,7 +457,8 @@ struct ParameterAlias {
"feature_fraction_seed", "enable_bundle", "data_filename", "valid_data_filenames",
"snapshot_freq", "verbosity", "sparse_threshold", "enable_load_from_binary_file",
"max_conflict_rate", "poisson_max_delta_step", "gaussian_eta",
"histogram_pool_size", "output_freq", "is_provide_training_metric", "machine_list_filename", "zero_as_missing"
"histogram_pool_size", "output_freq", "is_provide_training_metric", "machine_list_filename", "zero_as_missing",
"init_score_file", "valid_init_score_file"
});
std::unordered_map<std::string, std::string> tmp_map;
for (const auto& pair : *params) {
......
......@@ -44,7 +44,7 @@ public:
* \param data_filename Filename of data
* \param initscore_file Filename of initial score
*/
void Init(const char* data_filename);
void Init(const char* data_filename, const char* initscore_file);
/*!
* \brief init as subset
* \param metadata Filename of data
......@@ -211,7 +211,7 @@ public:
private:
/*! \brief Load initial scores from file */
void LoadInitialScore();
void LoadInitialScore(const char* initscore_file);
/*! \brief Load weights from file */
void LoadWeights();
/*! \brief Load query boundaries from file */
......
......@@ -12,13 +12,13 @@ public:
LIGHTGBM_EXPORT ~DatasetLoader();
LIGHTGBM_EXPORT Dataset* LoadFromFile(const char* filename, int rank, int num_machines);
LIGHTGBM_EXPORT Dataset* LoadFromFile(const char* filename, const char* initscore_file, int rank, int num_machines);
LIGHTGBM_EXPORT Dataset* LoadFromFile(const char* filename) {
return LoadFromFile(filename, 0, 1);
LIGHTGBM_EXPORT Dataset* LoadFromFile(const char* filename, const char* initscore_file) {
return LoadFromFile(filename, initscore_file, 0, 1);
}
LIGHTGBM_EXPORT Dataset* LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data);
LIGHTGBM_EXPORT Dataset* LoadFromFileAlignWithOtherDataset(const char* filename, const char* initscore_file, const Dataset* train_data);
LIGHTGBM_EXPORT Dataset* CostructFromSampleData(double** sample_values,
int** sample_indices, int num_col, const int* num_per_col,
......
......@@ -126,10 +126,12 @@ void Application::LoadData() {
if (config_.is_parallel_find_bin) {
// load data for parallel training
train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(),
config_.io_config.initscore_filename.c_str(),
Network::rank(), Network::num_machines()));
} else {
// load data for single machine
train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), 0, 1));
train_data_.reset(dataset_loader.LoadFromFile(config_.io_config.data_filename.c_str(), config_.io_config.initscore_filename.c_str(),
0, 1));
}
// need save binary file
if (config_.io_config.is_save_binary_file) {
......@@ -156,6 +158,7 @@ void Application::LoadData() {
auto new_dataset = std::unique_ptr<Dataset>(
dataset_loader.LoadFromFileAlignWithOtherDataset(
config_.io_config.valid_data_filenames[i].c_str(),
config_.io_config.valid_data_initscores[i].c_str(),
train_data_.get())
);
valid_datas_.push_back(std::move(new_dataset));
......
......@@ -350,11 +350,11 @@ int LGBM_DatasetCreateFromFile(const char* filename,
if (config.num_threads > 0) {
omp_set_num_threads(config.num_threads);
}
DatasetLoader loader(config.io_config, nullptr, 1, filename);
DatasetLoader loader(config.io_config,nullptr, 1, filename);
if (reference == nullptr) {
*out = loader.LoadFromFile(filename);
*out = loader.LoadFromFile(filename, "");
} else {
*out = loader.LoadFromFileAlignWithOtherDataset(filename,
*out = loader.LoadFromFileAlignWithOtherDataset(filename, "",
reinterpret_cast<const Dataset*>(reference));
}
API_END();
......
......@@ -245,6 +245,7 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
GetInt(params, "num_class", &num_class);
GetInt(params, "data_random_seed", &data_random_seed);
GetString(params, "data", &data_filename);
GetString(params, "init_score_file", &initscore_filename);
GetInt(params, "verbose", &verbosity);
GetInt(params, "num_iteration_predict", &num_iteration_predict);
GetInt(params, "bin_construct_sample_cnt", &bin_construct_sample_cnt);
......@@ -265,6 +266,12 @@ void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
if (GetString(params, "valid_data", &tmp_str)) {
valid_data_filenames = Common::Split(tmp_str.c_str(), ',');
}
if (GetString(params, "valid_init_score_file", &tmp_str)) {
valid_data_initscores = Common::Split(tmp_str.c_str(), ',');
} else {
valid_data_initscores = std::vector<std::string>(valid_data_filenames.size(), "");
}
CHECK(valid_data_filenames.size() == valid_data_initscores.size());
GetBool(params, "has_header", &has_header);
GetString(params, "label_column", &label_column);
GetString(params, "weight_column", &weight_column);
......
......@@ -156,7 +156,7 @@ void DatasetLoader::SetHeader(const char* filename) {
}
}
Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_machines) {
Dataset* DatasetLoader::LoadFromFile(const char* filename, const char* initscore_file, int rank, int num_machines) {
// don't support query id in data file when training in parallel
if (num_machines > 1 && !io_config_.is_pre_partition) {
if (group_idx_ > 0) {
......@@ -175,7 +175,7 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
}
dataset->data_filename_ = filename;
dataset->label_idx_ = label_idx_;
dataset->metadata_.Init(filename);
dataset->metadata_.Init(filename, initscore_file);
if (!io_config_.use_two_round_loading) {
// read data to memory
auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
......@@ -218,7 +218,7 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const char* initscore_file, const Dataset* train_data) {
data_size_t num_global_data = 0;
std::vector<data_size_t> used_data_indices;
auto dataset = std::unique_ptr<Dataset>(new Dataset());
......@@ -230,7 +230,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
}
dataset->data_filename_ = filename;
dataset->label_idx_ = label_idx_;
dataset->metadata_.Init(filename);
dataset->metadata_.Init(filename, initscore_file);
if (!io_config_.use_two_round_loading) {
// read data in memory
auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices);
......
......@@ -17,13 +17,13 @@ Metadata::Metadata() {
init_score_load_from_file_ = false;
}
void Metadata::Init(const char * data_filename) {
void Metadata::Init(const char * data_filename, const char* initscore_file) {
data_filename_ = data_filename;
// for lambdarank, it needs query data for partition data in parallel learning
LoadQueryBoundaries();
LoadWeights();
LoadQueryWeights();
LoadInitialScore();
LoadInitialScore(initscore_file);
}
Metadata::~Metadata() {
......@@ -389,11 +389,14 @@ void Metadata::LoadWeights() {
weight_load_from_file_ = true;
}
void Metadata::LoadInitialScore() {
void Metadata::LoadInitialScore(const char* initscore_file) {
num_init_score_ = 0;
std::string init_score_filename(data_filename_);
// default weight file name
init_score_filename.append(".init");
std::string init_score_filename(initscore_file);
if (init_score_filename.size() <= 0) {
init_score_filename = std::string(data_filename_);
// default initial score file name
init_score_filename.append(".init");
}
TextReader<size_t> reader(init_score_filename.c_str(), false);
reader.ReadAllLines();
if (reader.Lines().empty()) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment