Commit 1765b2e3 authored by Guolin Ke's avatar Guolin Ke
Browse files

fix partition error when setting weight_column

parent 7150c722
...@@ -88,8 +88,6 @@ public: ...@@ -88,8 +88,6 @@ public:
void SetQuery(const data_size_t* query, data_size_t len); void SetQuery(const data_size_t* query, data_size_t len);
void SetQueryId(const data_size_t* query_id, data_size_t len);
/*! /*!
* \brief Set initial scores * \brief Set initial scores
* \param init_score Initial scores, this class will manage memory for init_score. * \param init_score Initial scores, this class will manage memory for init_score.
...@@ -244,6 +242,9 @@ private: ...@@ -244,6 +242,9 @@ private:
std::vector<data_size_t> queries_; std::vector<data_size_t> queries_;
/*! \brief mutex for threading safe call */ /*! \brief mutex for threading safe call */
std::mutex mutex_; std::mutex mutex_;
bool weight_load_from_file_;
bool query_load_from_file_;
bool init_score_load_from_file_;
}; };
......
...@@ -20,8 +20,6 @@ public: ...@@ -20,8 +20,6 @@ public:
LIGHTGBM_EXPORT Dataset* LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data); LIGHTGBM_EXPORT Dataset* LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data);
LIGHTGBM_EXPORT Dataset* LoadFromBinFile(const char* data_filename, const char* bin_filename, int rank, int num_machines);
LIGHTGBM_EXPORT Dataset* CostructFromSampleData(std::vector<std::vector<double>>& sample_values, size_t total_sample_size, data_size_t num_data); LIGHTGBM_EXPORT Dataset* CostructFromSampleData(std::vector<std::vector<double>>& sample_values, size_t total_sample_size, data_size_t num_data);
/*! \brief Disable copy */ /*! \brief Disable copy */
...@@ -31,6 +29,8 @@ public: ...@@ -31,6 +29,8 @@ public:
private: private:
Dataset* LoadFromBinFile(const char* data_filename, const char* bin_filename, int rank, int num_machines, int* num_global_data, std::vector<data_size_t>* used_data_indices);
void SetHeader(const char* filename); void SetHeader(const char* filename);
void CheckDataset(const Dataset* dataset); void CheckDataset(const Dataset* dataset);
......
...@@ -112,8 +112,6 @@ bool Dataset::SetIntField(const char* field_name, const int* field_data, data_si ...@@ -112,8 +112,6 @@ bool Dataset::SetIntField(const char* field_name, const int* field_data, data_si
name = Common::Trim(name); name = Common::Trim(name);
if (name == std::string("query") || name == std::string("group")) { if (name == std::string("query") || name == std::string("group")) {
metadata_.SetQuery(field_data, num_element); metadata_.SetQuery(field_data, num_element);
} else if (name == std::string("query_id") || name == std::string("group_id")) {
metadata_.SetQueryId(field_data, num_element);
} else { } else {
return false; return false;
} }
......
...@@ -209,7 +209,7 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac ...@@ -209,7 +209,7 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
} }
} else { } else {
// load data from binary file // load data from binary file
dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines)); dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
} }
// check meta data // check meta data
dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices); dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
...@@ -255,7 +255,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, ...@@ -255,7 +255,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
} }
} else { } else {
// load data from binary file // load data from binary file
dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), 0, 1)); dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), 0, 1, &num_global_data, &used_data_indices));
} }
// not need to check validation data // not need to check validation data
// check meta data // check meta data
...@@ -263,7 +263,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, ...@@ -263,7 +263,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
return dataset.release(); return dataset.release();
} }
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename, int rank, int num_machines) { Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename, int rank, int num_machines, int* num_global_data, std::vector<data_size_t>* used_data_indices) {
auto dataset = std::unique_ptr<Dataset>(new Dataset()); auto dataset = std::unique_ptr<Dataset>(new Dataset());
FILE* file; FILE* file;
#ifdef _MSC_VER #ifdef _MSC_VER
...@@ -364,8 +364,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -364,8 +364,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
// load meta data // load meta data
dataset->metadata_.LoadFromMemory(buffer.data()); dataset->metadata_.LoadFromMemory(buffer.data());
std::vector<data_size_t> used_data_indices; *num_global_data = dataset->num_data_;
data_size_t num_global_data = dataset->num_data_; used_data_indices->clear();
// sample local used data if need to partition // sample local used data if need to partition
if (num_machines > 1 && !io_config_.is_pre_partition) { if (num_machines > 1 && !io_config_.is_pre_partition) {
const data_size_t* query_boundaries = dataset->metadata_.query_boundaries(); const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
...@@ -373,7 +373,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -373,7 +373,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
// if not contain query file, minimal sample unit is one record // if not contain query file, minimal sample unit is one record
for (data_size_t i = 0; i < dataset->num_data_; ++i) { for (data_size_t i = 0; i < dataset->num_data_; ++i) {
if (random_.NextInt(0, num_machines) == rank) { if (random_.NextInt(0, num_machines) == rank) {
used_data_indices.push_back(i); used_data_indices->push_back(i);
} }
} }
} else { } else {
...@@ -394,13 +394,13 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -394,13 +394,13 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
++qid; ++qid;
} }
if (is_query_used) { if (is_query_used) {
used_data_indices.push_back(i); used_data_indices->push_back(i);
} }
} }
} }
dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size()); dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
} }
dataset->metadata_.PartitionLabel(used_data_indices); dataset->metadata_.PartitionLabel(*used_data_indices);
// read feature data // read feature data
for (int i = 0; i < dataset->num_features_; ++i) { for (int i = 0; i < dataset->num_features_; ++i) {
// read feature size // read feature size
...@@ -422,8 +422,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -422,8 +422,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
} }
dataset->features_.emplace_back(std::unique_ptr<Feature>( dataset->features_.emplace_back(std::unique_ptr<Feature>(
new Feature(buffer.data(), new Feature(buffer.data(),
num_global_data, *num_global_data,
used_data_indices) *used_data_indices)
)); ));
} }
dataset->features_.shrink_to_fit(); dataset->features_.shrink_to_fit();
......
...@@ -12,6 +12,9 @@ Metadata::Metadata() { ...@@ -12,6 +12,9 @@ Metadata::Metadata() {
num_init_score_ = 0; num_init_score_ = 0;
num_data_ = 0; num_data_ = 0;
num_queries_ = 0; num_queries_ = 0;
weight_load_from_file_ = false;
query_load_from_file_ = false;
init_score_load_from_file_ = false;
} }
void Metadata::Init(const char * data_filename) { void Metadata::Init(const char * data_filename) {
...@@ -40,6 +43,7 @@ void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) { ...@@ -40,6 +43,7 @@ void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
for (data_size_t i = 0; i < num_weights_; ++i) { for (data_size_t i = 0; i < num_weights_; ++i) {
weights_[i] = 0.0f; weights_[i] = 0.0f;
} }
weight_load_from_file_ = false;
} }
if (query_idx >= 0) { if (query_idx >= 0) {
if (!query_boundaries_.empty()) { if (!query_boundaries_.empty()) {
...@@ -52,6 +56,7 @@ void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) { ...@@ -52,6 +56,7 @@ void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
for (data_size_t i = 0; i < num_data_; ++i) { for (data_size_t i = 0; i < num_data_; ++i) {
queries_[i] = 0; queries_[i] = 0;
} }
query_load_from_file_ = false;
} }
} }
...@@ -185,27 +190,17 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -185,27 +190,17 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
Log::Fatal("Initial score size doesn't match data size"); Log::Fatal("Initial score size doesn't match data size");
} }
} else { } else {
if (!queries_.empty()) {
Log::Fatal("Cannot used query_id for parallel training");
}
data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size()); data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());
// check weights // check weights
if (weight_load_from_file_) {
if (weights_.size() > 0 && num_weights_ != num_all_data) { if (weights_.size() > 0 && num_weights_ != num_all_data) {
weights_.clear(); weights_.clear();
num_weights_ = 0; num_weights_ = 0;
Log::Fatal("Weights size doesn't match data size"); Log::Fatal("Weights size doesn't match data size");
} }
// check query boundries
if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_all_data) {
query_boundaries_.clear();
num_queries_ = 0;
Log::Fatal("Query size doesn't match data size");
}
// contain initial score file
if (!init_score_.empty() && (num_init_score_ % num_all_data) != 0) {
init_score_.clear();
num_init_score_ = 0;
Log::Fatal("Initial score size doesn't match data size");
}
// get local weights // get local weights
if (!weights_.empty()) { if (!weights_.empty()) {
auto old_weights = weights_; auto old_weights = weights_;
...@@ -217,7 +212,14 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -217,7 +212,14 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
} }
old_weights.clear(); old_weights.clear();
} }
}
if (query_load_from_file_) {
// check query boundries
if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_all_data) {
query_boundaries_.clear();
num_queries_ = 0;
Log::Fatal("Query size doesn't match data size");
}
// get local query boundaries // get local query boundaries
if (!query_boundaries_.empty()) { if (!query_boundaries_.empty()) {
std::vector<data_size_t> used_query; std::vector<data_size_t> used_query;
...@@ -250,6 +252,14 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -250,6 +252,14 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
} }
old_query_boundaries.clear(); old_query_boundaries.clear();
} }
}
if (init_score_load_from_file_) {
// contain initial score file
if (!init_score_.empty() && (num_init_score_ % num_all_data) != 0) {
init_score_.clear();
num_init_score_ = 0;
Log::Fatal("Initial score size doesn't match data size");
}
// get local initial scores // get local initial scores
if (!init_score_.empty()) { if (!init_score_.empty()) {
...@@ -258,14 +268,14 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data ...@@ -258,14 +268,14 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
num_init_score_ = static_cast<int64_t>(num_data_) * num_class; num_init_score_ = static_cast<int64_t>(num_data_) * num_class;
init_score_ = std::vector<double>(num_init_score_); init_score_ = std::vector<double>(num_init_score_);
#pragma omp parallel for schedule(static) #pragma omp parallel for schedule(static)
for (int k = 0; k < num_class; ++k){ for (int k = 0; k < num_class; ++k) {
for (size_t i = 0; i < used_data_indices.size(); ++i) { for (size_t i = 0; i < used_data_indices.size(); ++i) {
init_score_[k * num_data_ + i] = old_scores[k * num_all_data + used_data_indices[i]]; init_score_[k * num_data_ + i] = old_scores[k * num_all_data + used_data_indices[i]];
} }
} }
old_scores.clear(); old_scores.clear();
} }
}
// re-load query weight // re-load query weight
LoadQueryWeights(); LoadQueryWeights();
} }
...@@ -289,6 +299,7 @@ void Metadata::SetInitScore(const double* init_score, data_size_t len) { ...@@ -289,6 +299,7 @@ void Metadata::SetInitScore(const double* init_score, data_size_t len) {
for (int64_t i = 0; i < num_init_score_; ++i) { for (int64_t i = 0; i < num_init_score_; ++i) {
init_score_[i] = init_score[i]; init_score_[i] = init_score[i];
} }
init_score_load_from_file_ = false;
} }
void Metadata::SetLabel(const float* label, data_size_t len) { void Metadata::SetLabel(const float* label, data_size_t len) {
...@@ -326,6 +337,7 @@ void Metadata::SetWeights(const float* weights, data_size_t len) { ...@@ -326,6 +337,7 @@ void Metadata::SetWeights(const float* weights, data_size_t len) {
weights_[i] = weights[i]; weights_[i] = weights[i];
} }
LoadQueryWeights(); LoadQueryWeights();
weight_load_from_file_ = false;
} }
void Metadata::SetQuery(const data_size_t* query, data_size_t len) { void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
...@@ -352,48 +364,7 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) { ...@@ -352,48 +364,7 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
query_boundaries_[i + 1] = query_boundaries_[i] + query[i]; query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
} }
LoadQueryWeights(); LoadQueryWeights();
} query_load_from_file_ = false;
void Metadata::SetQueryId(const data_size_t* query_id, data_size_t len) {
std::lock_guard<std::mutex> lock(mutex_);
// save to nullptr
if (query_id == nullptr || len == 0) {
query_boundaries_.clear();
queries_.clear();
num_queries_ = 0;
return;
}
if (num_data_ != len) {
Log::Fatal("len of query id is not same with #data");
}
if (!queries_.empty()) { queries_.clear(); }
queries_ = std::vector<data_size_t>(num_data_);
for (data_size_t i = 0; i < num_weights_; ++i) {
queries_[i] = query_id[i];
}
// need convert query_id to boundaries
std::vector<data_size_t> tmp_buffer;
data_size_t last_qid = -1;
data_size_t cur_cnt = 0;
for (data_size_t i = 0; i < num_data_; ++i) {
if (last_qid != queries_[i]) {
if (cur_cnt > 0) {
tmp_buffer.push_back(cur_cnt);
}
cur_cnt = 0;
last_qid = queries_[i];
}
++cur_cnt;
}
tmp_buffer.push_back(cur_cnt);
query_boundaries_ = std::vector<data_size_t>(tmp_buffer.size() + 1);
num_queries_ = static_cast<data_size_t>(tmp_buffer.size());
query_boundaries_[0] = 0;
for (size_t i = 0; i < tmp_buffer.size(); ++i) {
query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i];
}
queries_.clear();
LoadQueryWeights();
} }
void Metadata::LoadWeights() { void Metadata::LoadWeights() {
...@@ -415,6 +386,7 @@ void Metadata::LoadWeights() { ...@@ -415,6 +386,7 @@ void Metadata::LoadWeights() {
Common::Atof(reader.Lines()[i].c_str(), &tmp_weight); Common::Atof(reader.Lines()[i].c_str(), &tmp_weight);
weights_[i] = static_cast<float>(tmp_weight); weights_[i] = static_cast<float>(tmp_weight);
} }
weight_load_from_file_ = true;
} }
void Metadata::LoadInitialScore() { void Metadata::LoadInitialScore() {
...@@ -457,6 +429,7 @@ void Metadata::LoadInitialScore() { ...@@ -457,6 +429,7 @@ void Metadata::LoadInitialScore() {
} }
} }
} }
init_score_load_from_file_ = true;
} }
void Metadata::LoadQueryBoundaries() { void Metadata::LoadQueryBoundaries() {
...@@ -478,6 +451,7 @@ void Metadata::LoadQueryBoundaries() { ...@@ -478,6 +451,7 @@ void Metadata::LoadQueryBoundaries() {
Common::Atoi(reader.Lines()[i].c_str(), &tmp_cnt); Common::Atoi(reader.Lines()[i].c_str(), &tmp_cnt);
query_boundaries_[i + 1] = query_boundaries_[i] + static_cast<data_size_t>(tmp_cnt); query_boundaries_[i + 1] = query_boundaries_[i] + static_cast<data_size_t>(tmp_cnt);
} }
query_load_from_file_ = true;
} }
void Metadata::LoadQueryWeights() { void Metadata::LoadQueryWeights() {
...@@ -516,12 +490,14 @@ void Metadata::LoadFromMemory(const void* memory) { ...@@ -516,12 +490,14 @@ void Metadata::LoadFromMemory(const void* memory) {
weights_ = std::vector<float>(num_weights_); weights_ = std::vector<float>(num_weights_);
std::memcpy(weights_.data(), mem_ptr, sizeof(float)*num_weights_); std::memcpy(weights_.data(), mem_ptr, sizeof(float)*num_weights_);
mem_ptr += sizeof(float)*num_weights_; mem_ptr += sizeof(float)*num_weights_;
weight_load_from_file_ = true;
} }
if (num_queries_ > 0) { if (num_queries_ > 0) {
if (!query_boundaries_.empty()) { query_boundaries_.clear(); } if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1); query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
std::memcpy(query_boundaries_.data(), mem_ptr, sizeof(data_size_t)*(num_queries_ + 1)); std::memcpy(query_boundaries_.data(), mem_ptr, sizeof(data_size_t)*(num_queries_ + 1));
mem_ptr += sizeof(data_size_t)*(num_queries_ + 1); mem_ptr += sizeof(data_size_t)*(num_queries_ + 1);
query_load_from_file_ = true;
} }
LoadQueryWeights(); LoadQueryWeights();
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment