Commit 2a788165 authored by Guolin Ke
Browse files

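Fix a bug for the label_index: store label_idx_ in the binary dataset file header, restore it when loading a binary file, and set it when loading from a text file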

parent 9ae1f9ac
...@@ -526,7 +526,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) { ...@@ -526,7 +526,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
fwrite(binary_file_token, sizeof(char), size_of_token, file); fwrite(binary_file_token, sizeof(char), size_of_token, file);
// get size of header // get size of header
size_t size_of_header = sizeof(num_data_) + sizeof(num_features_) + sizeof(num_total_features_) size_t size_of_header = sizeof(num_data_) + sizeof(num_features_) + sizeof(num_total_features_)
+ sizeof(int) * num_total_features_ + sizeof(num_groups_) + sizeof(int) * num_total_features_ + sizeof(label_idx_) + sizeof(num_groups_)
+ 3 * sizeof(int) * num_features_ + sizeof(uint64_t) * (num_groups_ + 1) + 2 * sizeof(int) * num_groups_; + 3 * sizeof(int) * num_features_ + sizeof(uint64_t) * (num_groups_ + 1) + 2 * sizeof(int) * num_groups_;
// size of feature names // size of feature names
for (int i = 0; i < num_total_features_; ++i) { for (int i = 0; i < num_total_features_; ++i) {
...@@ -537,6 +537,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) { ...@@ -537,6 +537,7 @@ void Dataset::SaveBinaryFile(const char* bin_filename) {
fwrite(&num_data_, sizeof(num_data_), 1, file); fwrite(&num_data_, sizeof(num_data_), 1, file);
fwrite(&num_features_, sizeof(num_features_), 1, file); fwrite(&num_features_, sizeof(num_features_), 1, file);
fwrite(&num_total_features_, sizeof(num_total_features_), 1, file); fwrite(&num_total_features_, sizeof(num_total_features_), 1, file);
fwrite(&label_idx_, sizeof(label_idx_), 1, file);
fwrite(used_feature_map_.data(), sizeof(int), num_total_features_, file); fwrite(used_feature_map_.data(), sizeof(int), num_total_features_, file);
fwrite(&num_groups_, sizeof(num_groups_), 1, file); fwrite(&num_groups_, sizeof(num_groups_), 1, file);
fwrite(real_feature_idx_.data(), sizeof(int), num_features_, file); fwrite(real_feature_idx_.data(), sizeof(int), num_features_, file);
......
...@@ -174,6 +174,7 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac ...@@ -174,6 +174,7 @@ Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_mac
Log::Fatal("Could not recognize data format of %s", filename); Log::Fatal("Could not recognize data format of %s", filename);
} }
dataset->data_filename_ = filename; dataset->data_filename_ = filename;
dataset->label_idx_ = label_idx_;
dataset->metadata_.Init(filename); dataset->metadata_.Init(filename);
if (!io_config_.use_two_round_loading) { if (!io_config_.use_two_round_loading) {
// read data to memory // read data to memory
...@@ -228,6 +229,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, ...@@ -228,6 +229,7 @@ Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename,
Log::Fatal("Could not recognize data format of %s", filename); Log::Fatal("Could not recognize data format of %s", filename);
} }
dataset->data_filename_ = filename; dataset->data_filename_ = filename;
dataset->label_idx_ = label_idx_;
dataset->metadata_.Init(filename); dataset->metadata_.Init(filename);
if (!io_config_.use_two_round_loading) { if (!io_config_.use_two_round_loading) {
// read data in memory // read data in memory
...@@ -315,6 +317,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b ...@@ -315,6 +317,8 @@ Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* b
mem_ptr += sizeof(dataset->num_features_); mem_ptr += sizeof(dataset->num_features_);
dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr)); dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
mem_ptr += sizeof(dataset->num_total_features_); mem_ptr += sizeof(dataset->num_total_features_);
dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
mem_ptr += sizeof(dataset->label_idx_);
const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr); const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
dataset->used_feature_map_.clear(); dataset->used_feature_map_.clear();
for (int i = 0; i < dataset->num_total_features_; ++i) { for (int i = 0; i < dataset->num_total_features_; ++i) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment