Commit 1765b2e3 authored by Guolin Ke

fix partition error when setting weight_column

parent 7150c722
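The diff below tracks, per metadata field, whether it was loaded from a file (weight_load_from_file_, query_load_from_file_, init_score_load_from_file_), and Metadata::CheckOrPartition only re-partitions fields that came from a file; values pushed in later through the Set* API are already sized to the local data and must not be partitioned again. The following is a minimal, self-contained sketch of that gating pattern; the struct and method names here are simplified illustrations, not the actual LightGBM classes beyond what appears in the diff.

```cpp
// Minimal sketch (hypothetical names) of the fix: only metadata that was
// loaded from file is re-indexed when the data is partitioned across machines.
#include <cstdio>
#include <vector>

struct MetadataSketch {
  std::vector<float> weights;
  bool weight_load_from_file = false;   // mirrors weight_load_from_file_

  void LoadWeightsFromFile(const std::vector<float>& w) {
    weights = w;
    weight_load_from_file = true;        // file-backed: partition later
  }
  void SetWeights(const std::vector<float>& w) {
    weights = w;
    weight_load_from_file = false;       // user-provided: already local
  }
  void CheckOrPartition(const std::vector<int>& used_data_indices) {
    if (!weight_load_from_file) return;  // the fix: skip non-file metadata
    std::vector<float> local(used_data_indices.size());
    for (size_t i = 0; i < used_data_indices.size(); ++i) {
      local[i] = weights[used_data_indices[i]];
    }
    weights.swap(local);
  }
};

int main() {
  MetadataSketch m;
  m.SetWeights({0.5f, 1.0f});            // e.g. weights set via weight_column / C API
  m.CheckOrPartition({0});               // no-op now; re-indexing already-local
                                         // weights here was the reported bug
  std::printf("%zu weights kept\n", m.weights.size());
  return 0;
}
```

Relatedly, DatasetLoader::LoadFromBinFile now returns num_global_data and used_data_indices through out-parameters, so the caller in LoadFromFile / LoadFromFileAlignWithOtherDataset can pass the same partition information to metadata_.CheckOrPartition.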
......@@ -88,8 +88,6 @@ public:
void SetQuery(const data_size_t* query, data_size_t len);
void SetQueryId(const data_size_t* query_id, data_size_t len);
/*!
* \brief Set initial scores
* \param init_score Initial scores, this class will manage memory for init_score.
......@@ -244,6 +242,9 @@ private:
std::vector<data_size_t> queries_;
/*! \brief mutex for threading safe call */
std::mutex mutex_;
bool weight_load_from_file_;
bool query_load_from_file_;
bool init_score_load_from_file_;
};
......
......@@ -20,8 +20,6 @@ public:
LIGHTGBM_EXPORT Dataset* LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data);
LIGHTGBM_EXPORT Dataset* LoadFromBinFile(const char* data_filename, const char* bin_filename, int rank, int num_machines);
LIGHTGBM_EXPORT Dataset* CostructFromSampleData(std::vector<std::vector<double>>& sample_values, size_t total_sample_size, data_size_t num_data);
/*! \brief Disable copy */
......@@ -31,6 +29,8 @@ public:
private:
Dataset* LoadFromBinFile(const char* data_filename, const char* bin_filename, int rank, int num_machines, int* num_global_data, std::vector<data_size_t>* used_data_indices);
void SetHeader(const char* filename);
void CheckDataset(const Dataset* dataset);
......
......@@ -112,8 +112,6 @@ bool Dataset::SetIntField(const char* field_name, const int* field_data, data_si
name = Common::Trim(name);
if (name == std::string("query") || name == std::string("group")) {
metadata_.SetQuery(field_data, num_element);
} else if (name == std::string("query_id") || name == std::string("group_id")) {
metadata_.SetQueryId(field_data, num_element);
} else {
return false;
}
......
......@@ -209,7 +209,7 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
}
} else {
// load data from binary file
dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines));
dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
}
// check meta data
dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
......@@ -255,7 +255,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
}
} else {
// load data from binary file
dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), 0, 1));
dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), 0, 1, &num_global_data, &used_data_indices));
}
// not need to check validation data
// check meta data
......@@ -263,7 +263,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
return dataset.release();
}
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename, int rank, int num_machines) {
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename, int rank, int num_machines, int* num_global_data, std::vector<data_size_t>* used_data_indices) {
auto dataset = std::unique_ptr<Dataset>(new Dataset());
FILE* file;
#ifdef _MSC_VER
......@@ -364,8 +364,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
// load meta data
dataset->metadata_.LoadFromMemory(buffer.data());
std::vector<data_size_t> used_data_indices;
data_size_t num_global_data = dataset->num_data_;
*num_global_data = dataset->num_data_;
used_data_indices->clear();
// sample local used data if need to partition
if (num_machines > 1 && !io_config_.is_pre_partition) {
const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
......@@ -373,7 +373,7 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
// if not contain query file, minimal sample unit is one record
for (data_size_t i = 0; i < dataset->num_data_; ++i) {
if (random_.NextInt(0, num_machines) == rank) {
used_data_indices.push_back(i);
used_data_indices->push_back(i);
}
}
} else {
......@@ -394,13 +394,13 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
++qid;
}
if (is_query_used) {
used_data_indices.push_back(i);
used_data_indices->push_back(i);
}
}
}
dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
}
dataset->metadata_.PartitionLabel(used_data_indices);
dataset->metadata_.PartitionLabel(*used_data_indices);
// read feature data
for (int i = 0; i < dataset->num_features_; ++i) {
// read feature size
......@@ -422,8 +422,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
}
dataset->features_.emplace_back(std::unique_ptr<Feature>(
new Feature(buffer.data(),
num_global_data,
used_data_indices)
*num_global_data,
*used_data_indices)
));
}
dataset->features_.shrink_to_fit();
......
......@@ -12,6 +12,9 @@ Metadata::Metadata() {
num_init_score_ = 0;
num_data_ = 0;
num_queries_ = 0;
weight_load_from_file_ = false;
query_load_from_file_ = false;
init_score_load_from_file_ = false;
}
void Metadata::Init(const char * data_filename) {
......@@ -40,6 +43,7 @@ void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
for (data_size_t i = 0; i < num_weights_; ++i) {
weights_[i] = 0.0f;
}
weight_load_from_file_ = false;
}
if (query_idx >= 0) {
if (!query_boundaries_.empty()) {
......@@ -52,6 +56,7 @@ void Metadata::Init(data_size_t num_data, int weight_idx, int query_idx) {
for (data_size_t i = 0; i < num_data_; ++i) {
queries_[i] = 0;
}
query_load_from_file_ = false;
}
}
......@@ -185,87 +190,92 @@ void Metadata::CheckOrPartition(data_size_t num_all_data, const std::vector<data
Log::Fatal("Initial score size doesn't match data size");
}
} else {
if (!queries_.empty()) {
Log::Fatal("Cannot used query_id for parallel training");
}
data_size_t num_used_data = static_cast<data_size_t>(used_data_indices.size());
// check weights
if (weights_.size() > 0 && num_weights_ != num_all_data) {
weights_.clear();
num_weights_ = 0;
Log::Fatal("Weights size doesn't match data size");
}
// check query boundries
if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_all_data) {
query_boundaries_.clear();
num_queries_ = 0;
Log::Fatal("Query size doesn't match data size");
}
// contain initial score file
if (!init_score_.empty() && (num_init_score_ % num_all_data) != 0) {
init_score_.clear();
num_init_score_ = 0;
Log::Fatal("Initial score size doesn't match data size");
}
// get local weights
if (!weights_.empty()) {
auto old_weights = weights_;
num_weights_ = num_data_;
weights_ = std::vector<float>(num_data_);
if (weight_load_from_file_) {
if (weights_.size() > 0 && num_weights_ != num_all_data) {
weights_.clear();
num_weights_ = 0;
Log::Fatal("Weights size doesn't match data size");
}
// get local weights
if (!weights_.empty()) {
auto old_weights = weights_;
num_weights_ = num_data_;
weights_ = std::vector<float>(num_data_);
#pragma omp parallel for schedule(static)
for (int i = 0; i < static_cast<int>(used_data_indices.size()); ++i) {
weights_[i] = old_weights[used_data_indices[i]];
for (int i = 0; i < static_cast<int>(used_data_indices.size()); ++i) {
weights_[i] = old_weights[used_data_indices[i]];
}
old_weights.clear();
}
old_weights.clear();
}
// get local query boundaries
if (!query_boundaries_.empty()) {
std::vector<data_size_t> used_query;
data_size_t data_idx = 0;
for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_data; ++qid) {
data_size_t start = query_boundaries_[qid];
data_size_t end = query_boundaries_[qid + 1];
data_size_t len = end - start;
if (used_data_indices[data_idx] > start) {
continue;
} else if (used_data_indices[data_idx] == start) {
if (num_used_data >= data_idx + len && used_data_indices[data_idx + len - 1] == end - 1) {
used_query.push_back(qid);
data_idx += len;
if (query_load_from_file_) {
// check query boundries
if (!query_boundaries_.empty() && query_boundaries_[num_queries_] != num_all_data) {
query_boundaries_.clear();
num_queries_ = 0;
Log::Fatal("Query size doesn't match data size");
}
// get local query boundaries
if (!query_boundaries_.empty()) {
std::vector<data_size_t> used_query;
data_size_t data_idx = 0;
for (data_size_t qid = 0; qid < num_queries_ && data_idx < num_used_data; ++qid) {
data_size_t start = query_boundaries_[qid];
data_size_t end = query_boundaries_[qid + 1];
data_size_t len = end - start;
if (used_data_indices[data_idx] > start) {
continue;
} else if (used_data_indices[data_idx] == start) {
if (num_used_data >= data_idx + len && used_data_indices[data_idx + len - 1] == end - 1) {
used_query.push_back(qid);
data_idx += len;
} else {
Log::Fatal("Data partition error, data didn't match queries");
}
} else {
Log::Fatal("Data partition error, data didn't match queries");
}
} else {
Log::Fatal("Data partition error, data didn't match queries");
}
auto old_query_boundaries = query_boundaries_;
query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
num_queries_ = static_cast<data_size_t>(used_query.size());
query_boundaries_[0] = 0;
for (data_size_t i = 0; i < num_queries_; ++i) {
data_size_t qid = used_query[i];
data_size_t len = old_query_boundaries[qid + 1] - old_query_boundaries[qid];
query_boundaries_[i + 1] = query_boundaries_[i] + len;
}
old_query_boundaries.clear();
}
auto old_query_boundaries = query_boundaries_;
query_boundaries_ = std::vector<data_size_t>(used_query.size() + 1);
num_queries_ = static_cast<data_size_t>(used_query.size());
query_boundaries_[0] = 0;
for (data_size_t i = 0; i < num_queries_; ++i) {
data_size_t qid = used_query[i];
data_size_t len = old_query_boundaries[qid + 1] - old_query_boundaries[qid];
query_boundaries_[i + 1] = query_boundaries_[i] + len;
}
old_query_boundaries.clear();
}
if (init_score_load_from_file_) {
// contain initial score file
if (!init_score_.empty() && (num_init_score_ % num_all_data) != 0) {
init_score_.clear();
num_init_score_ = 0;
Log::Fatal("Initial score size doesn't match data size");
}
// get local initial scores
if (!init_score_.empty()) {
auto old_scores = init_score_;
int num_class = static_cast<int>(num_init_score_ / num_all_data);
num_init_score_ = static_cast<int64_t>(num_data_) * num_class;
init_score_ = std::vector<double>(num_init_score_);
// get local initial scores
if (!init_score_.empty()) {
auto old_scores = init_score_;
int num_class = static_cast<int>(num_init_score_ / num_all_data);
num_init_score_ = static_cast<int64_t>(num_data_) * num_class;
init_score_ = std::vector<double>(num_init_score_);
#pragma omp parallel for schedule(static)
for (int k = 0; k < num_class; ++k){
for (size_t i = 0; i < used_data_indices.size(); ++i) {
init_score_[k * num_data_ + i] = old_scores[k * num_all_data + used_data_indices[i]];
for (int k = 0; k < num_class; ++k) {
for (size_t i = 0; i < used_data_indices.size(); ++i) {
init_score_[k * num_data_ + i] = old_scores[k * num_all_data + used_data_indices[i]];
}
}
old_scores.clear();
}
old_scores.clear();
}
// re-load query weight
LoadQueryWeights();
}
......@@ -289,6 +299,7 @@ void Metadata::SetInitScore(const double* init_score, data_size_t len) {
for (int64_t i = 0; i < num_init_score_; ++i) {
init_score_[i] = init_score[i];
}
init_score_load_from_file_ = false;
}
void Metadata::SetLabel(const float* label, data_size_t len) {
......@@ -326,6 +337,7 @@ void Metadata::SetWeights(const float* weights, data_size_t len) {
weights_[i] = weights[i];
}
LoadQueryWeights();
weight_load_from_file_ = false;
}
void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
......@@ -352,48 +364,7 @@ void Metadata::SetQuery(const data_size_t* query, data_size_t len) {
query_boundaries_[i + 1] = query_boundaries_[i] + query[i];
}
LoadQueryWeights();
}
void Metadata::SetQueryId(const data_size_t* query_id, data_size_t len) {
std::lock_guard<std::mutex> lock(mutex_);
// save to nullptr
if (query_id == nullptr || len == 0) {
query_boundaries_.clear();
queries_.clear();
num_queries_ = 0;
return;
}
if (num_data_ != len) {
Log::Fatal("len of query id is not same with #data");
}
if (!queries_.empty()) { queries_.clear(); }
queries_ = std::vector<data_size_t>(num_data_);
for (data_size_t i = 0; i < num_weights_; ++i) {
queries_[i] = query_id[i];
}
// need convert query_id to boundaries
std::vector<data_size_t> tmp_buffer;
data_size_t last_qid = -1;
data_size_t cur_cnt = 0;
for (data_size_t i = 0; i < num_data_; ++i) {
if (last_qid != queries_[i]) {
if (cur_cnt > 0) {
tmp_buffer.push_back(cur_cnt);
}
cur_cnt = 0;
last_qid = queries_[i];
}
++cur_cnt;
}
tmp_buffer.push_back(cur_cnt);
query_boundaries_ = std::vector<data_size_t>(tmp_buffer.size() + 1);
num_queries_ = static_cast<data_size_t>(tmp_buffer.size());
query_boundaries_[0] = 0;
for (size_t i = 0; i < tmp_buffer.size(); ++i) {
query_boundaries_[i + 1] = query_boundaries_[i] + tmp_buffer[i];
}
queries_.clear();
LoadQueryWeights();
query_load_from_file_ = false;
}
void Metadata::LoadWeights() {
......@@ -415,6 +386,7 @@ void Metadata::LoadWeights() {
Common::Atof(reader.Lines()[i].c_str(), &tmp_weight);
weights_[i] = static_cast<float>(tmp_weight);
}
weight_load_from_file_ = true;
}
void Metadata::LoadInitialScore() {
......@@ -457,6 +429,7 @@ void Metadata::LoadInitialScore() {
}
}
}
init_score_load_from_file_ = true;
}
void Metadata::LoadQueryBoundaries() {
......@@ -478,6 +451,7 @@ void Metadata::LoadQueryBoundaries() {
Common::Atoi(reader.Lines()[i].c_str(), &tmp_cnt);
query_boundaries_[i + 1] = query_boundaries_[i] + static_cast<data_size_t>(tmp_cnt);
}
query_load_from_file_ = true;
}
void Metadata::LoadQueryWeights() {
......@@ -516,12 +490,14 @@ void Metadata::LoadFromMemory(const void* memory) {
weights_ = std::vector<float>(num_weights_);
std::memcpy(weights_.data(), mem_ptr, sizeof(float)*num_weights_);
mem_ptr += sizeof(float)*num_weights_;
weight_load_from_file_ = true;
}
if (num_queries_ > 0) {
if (!query_boundaries_.empty()) { query_boundaries_.clear(); }
query_boundaries_ = std::vector<data_size_t>(num_queries_ + 1);
std::memcpy(query_boundaries_.data(), mem_ptr, sizeof(data_size_t)*(num_queries_ + 1));
mem_ptr += sizeof(data_size_t)*(num_queries_ + 1);
query_load_from_file_ = true;
}
LoadQueryWeights();
}
......