dataset_loader.cpp 53.5 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/dataset_loader.h>

Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/network.h>
8
9
10
#include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
11

12
13
#include <fstream>

14
15
#include <LightGBM/json11.hpp>

16
17
using namespace json11;

Guolin Ke's avatar
Guolin Ke committed
18
19
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
20
21
// Builds a loader bound to `io_config`.
// Column roles start out as: label in column 0, no weight column, no group column;
// SetHeader(filename) then resolves the configured label/weight/group/ignore/categorical
// columns (reading the file header when `filename` is non-null).
DatasetLoader::DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename)
    : config_(io_config),
      random_(config_.data_random_seed),
      predict_fun_(predict_fun),
      num_class_(num_class) {
  // optional columns are unset until the configuration says otherwise
  group_idx_ = NO_SPECIFIC;
  weight_idx_ = NO_SPECIFIC;
  // the label defaults to the first column
  label_idx_ = 0;
  SetHeader(filename);
}

// Nothing to release explicitly; members clean up after themselves.
DatasetLoader::~DatasetLoader() = default;

Guolin Ke's avatar
Guolin Ke committed
31
void DatasetLoader::SetHeader(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
32
  std::unordered_map<std::string, int> name2idx;
Guolin Ke's avatar
Guolin Ke committed
33
34
  std::string name_prefix("name:");
  if (filename != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
35
    TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
36

Guolin Ke's avatar
Guolin Ke committed
37
    // get column names
Guolin Ke's avatar
Guolin Ke committed
38
    if (config_.header) {
Guolin Ke's avatar
Guolin Ke committed
39
      std::string first_line = text_reader.first_line();
40
      feature_names_ = Common::Split(first_line.c_str(), "\t,");
Guolin Ke's avatar
Guolin Ke committed
41
42
    }

Guolin Ke's avatar
Guolin Ke committed
43
    // load label idx first
Guolin Ke's avatar
Guolin Ke committed
44
45
46
    if (config_.label_column.size() > 0) {
      if (Common::StartsWith(config_.label_column, name_prefix)) {
        std::string name = config_.label_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
47
48
49
50
51
52
53
54
55
56
        label_idx_ = -1;
        for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
          if (name == feature_names_[i]) {
            label_idx_ = i;
            break;
          }
        }
        if (label_idx_ >= 0) {
          Log::Info("Using column %s as label", name.c_str());
        } else {
57
58
          Log::Fatal("Could not find label column %s in data file \n"
                     "or data file doesn't contain header", name.c_str());
Guolin Ke's avatar
Guolin Ke committed
59
        }
Guolin Ke's avatar
Guolin Ke committed
60
      } else {
Guolin Ke's avatar
Guolin Ke committed
61
        if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
62
63
64
          Log::Fatal("label_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
65
66
        }
        Log::Info("Using column number %d as label", label_idx_);
Guolin Ke's avatar
Guolin Ke committed
67
68
      }
    }
Guolin Ke's avatar
Guolin Ke committed
69

Guolin Ke's avatar
Guolin Ke committed
70
    if (!feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
71
72
73
74
      // erase label column name
      feature_names_.erase(feature_names_.begin() + label_idx_);
      for (size_t i = 0; i < feature_names_.size(); ++i) {
        name2idx[feature_names_[i]] = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
75
      }
Guolin Ke's avatar
Guolin Ke committed
76
77
78
    }

    // load ignore columns
Guolin Ke's avatar
Guolin Ke committed
79
80
81
    if (config_.ignore_column.size() > 0) {
      if (Common::StartsWith(config_.ignore_column, name_prefix)) {
        std::string names = config_.ignore_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
82
83
84
85
86
87
88
89
90
        for (auto name : Common::Split(names.c_str(), ',')) {
          if (name2idx.count(name) > 0) {
            int tmp = name2idx[name];
            ignore_features_.emplace(tmp);
          } else {
            Log::Fatal("Could not find ignore column %s in data file", name.c_str());
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
91
        for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
Guolin Ke's avatar
Guolin Ke committed
92
93
          int tmp = 0;
          if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
94
95
96
            Log::Fatal("ignore_column is not a number,\n"
                       "if you want to use a column name,\n"
                       "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
97
98
          }
          ignore_features_.emplace(tmp);
Guolin Ke's avatar
Guolin Ke committed
99
100
101
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
102
    // load weight idx
Guolin Ke's avatar
Guolin Ke committed
103
104
105
    if (config_.weight_column.size() > 0) {
      if (Common::StartsWith(config_.weight_column, name_prefix)) {
        std::string name = config_.weight_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
106
107
108
109
110
111
        if (name2idx.count(name) > 0) {
          weight_idx_ = name2idx[name];
          Log::Info("Using column %s as weight", name.c_str());
        } else {
          Log::Fatal("Could not find weight column %s in data file", name.c_str());
        }
Guolin Ke's avatar
Guolin Ke committed
112
      } else {
Guolin Ke's avatar
Guolin Ke committed
113
        if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
114
115
116
          Log::Fatal("weight_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
117
118
        }
        Log::Info("Using column number %d as weight", weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
119
      }
Guolin Ke's avatar
Guolin Ke committed
120
      ignore_features_.emplace(weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
121
    }
Guolin Ke's avatar
Guolin Ke committed
122
    // load group idx
Guolin Ke's avatar
Guolin Ke committed
123
124
125
    if (config_.group_column.size() > 0) {
      if (Common::StartsWith(config_.group_column, name_prefix)) {
        std::string name = config_.group_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
126
127
128
129
130
131
132
        if (name2idx.count(name) > 0) {
          group_idx_ = name2idx[name];
          Log::Info("Using column %s as group/query id", name.c_str());
        } else {
          Log::Fatal("Could not find group/query column %s in data file", name.c_str());
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
133
        if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
134
135
136
          Log::Fatal("group_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
137
138
139
140
        }
        Log::Info("Using column number %d as group/query id", group_idx_);
      }
      ignore_features_.emplace(group_idx_);
Guolin Ke's avatar
Guolin Ke committed
141
142
    }
  }
Guolin Ke's avatar
Guolin Ke committed
143
144
145
  if (config_.categorical_feature.size() > 0) {
    if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
      std::string names = config_.categorical_feature.substr(name_prefix.size());
146
147
148
149
150
      for (auto name : Common::Split(names.c_str(), ',')) {
        if (name2idx.count(name) > 0) {
          int tmp = name2idx[name];
          categorical_features_.emplace(tmp);
        } else {
Guolin Ke's avatar
Guolin Ke committed
151
          Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
152
153
154
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
155
      for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
156
157
        int tmp = 0;
        if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Guolin Ke's avatar
Guolin Ke committed
158
          Log::Fatal("categorical_feature is not a number,\n"
159
160
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
161
162
163
164
165
        }
        categorical_features_.emplace(tmp);
      }
    }
  }
Guolin Ke's avatar
Guolin Ke committed
166
167
}

168
// Loads a training dataset from `filename` for machine `rank` out of `num_machines`.
//
// If CheckCanLoadFromBin returns a non-empty binary file name, the dataset is read from that
// binary file. Otherwise the text file is parsed and bin mappers are built from a sample:
//   * two_round == false: the (local part of the) file is loaded into memory, sampled there,
//     and features are extracted from memory;
//   * two_round == true: the sample is drawn directly from the file and a second pass over
//     the file extracts the features.
// Afterwards the metadata is checked/partitioned and CheckDataset validates the result.
// `initscore_file` is forwarded to Metadata::Init.
// Returns a newly allocated Dataset owned by the caller.
Dataset* DatasetLoader::LoadFromFile(const char* filename, const char* initscore_file, int rank, int num_machines) {
  // don't support query id in data file when training in parallel.
  // group_idx_ keeps the NO_SPECIFIC sentinel unless a group column was configured;
  // testing `> 0` would wrongly accept a group column located at index 0.
  if (num_machines > 1 && !config_.pre_partition) {
    if (group_idx_ != NO_SPECIFIC) {
      Log::Fatal("Using a query id without pre-partitioning the data file is not supported for parallel training.\n"
                 "Please use an additional query file or pre-partition the data");
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto bin_filename = CheckCanLoadFromBin(filename);
  if (bin_filename.size() == 0) {
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename, initscore_file);
    if (!config_.two_round) {
      // read data to memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // sample data
      auto sample_data = SampleTextDataFromMemory(text_data);
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      // sample data from file
      auto sample_data = SampleTextDataFromFile(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      // an empty index list means every row of the file is used locally
      if (used_data_indices.size() > 0) {
        dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
      } else {
        dataset->num_data_ = num_global_data;
      }
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      Log::Debug("Making second pass...");
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
  }
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  // need to check training data
  CheckDataset(dataset.get());
  return dataset.release();
}



230
// Loads a validation dataset from `filename` whose features are set up from `train_data`
// through Dataset::CreateValid rather than from a sample of this file.
// The whole file is used (rank 0 of 1 machine). If CheckCanLoadFromBin returns a
// non-empty binary file name, that binary file is loaded instead of parsing text.
// Validation data skips CheckDataset; only the metadata is checked/partitioned.
// Returns a newly allocated Dataset owned by the caller.
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const char* initscore_file, const Dataset* train_data) {
  std::vector<data_size_t> used_data_indices;
  data_size_t num_global_data = 0;
  auto dataset = std::unique_ptr<Dataset>(new Dataset());

  const auto bin_filename = CheckCanLoadFromBin(filename);
  if (!bin_filename.empty()) {
    // load data from binary file
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), 0, 1, &num_global_data, &used_data_indices));
  } else {
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename, initscore_file);

    // Common step once num_data_ is known: size the metadata, then align with training data.
    auto align_with_train = [&]() {
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      dataset->CreateValid(train_data);
    };

    if (config_.two_round) {
      // count the lines first, then stream the file again to extract features
      TextReader<data_size_t> line_counter(filename, config_.header);
      dataset->num_data_ = static_cast<data_size_t>(line_counter.CountLine());
      num_global_data = dataset->num_data_;
      align_with_train();
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    } else {
      // keep the whole file in memory and extract features from there
      auto lines = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(lines.size());
      align_with_train();
      ExtractFeaturesFromMemory(&lines, parser.get(), dataset.get());
      lines.clear();
    }
  }

  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  return dataset.release();
}

274
275
276
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
                                        int rank, int num_machines, int* num_global_data,
                                        std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
277
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
278
  auto reader = VirtualFileReader::Make(bin_filename);
Guolin Ke's avatar
Guolin Ke committed
279
  dataset->data_filename_ = data_filename;
280
  if (!reader->Init()) {
Guolin Ke's avatar
Guolin Ke committed
281
282
283
284
285
    Log::Fatal("Could not read binary data from %s", bin_filename);
  }

  // buffer to read binary file
  size_t buffer_size = 16 * 1024 * 1024;
Guolin Ke's avatar
Guolin Ke committed
286
  auto buffer = std::vector<char>(buffer_size);
287

288
289
  // check token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
290
291
  size_t read_cnt = reader->Read(buffer.data(), sizeof(char) * size_of_token);
  if (read_cnt != sizeof(char) * size_of_token) {
292
293
294
    Log::Fatal("Binary file error: token has the wrong size");
  }
  if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
295
    Log::Fatal("Input file is not LightGBM binary file");
296
  }
Guolin Ke's avatar
Guolin Ke committed
297
298

  // read size of header
299
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
300

301
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
302
303
304
    Log::Fatal("Binary file error: header has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
305
  size_t size_of_head = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
306
307
308
309

  // re-allocmate space if not enough
  if (size_of_head > buffer_size) {
    buffer_size = size_of_head;
Guolin Ke's avatar
Guolin Ke committed
310
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
311
312
  }
  // read header
313
  read_cnt = reader->Read(buffer.data(), size_of_head);
Guolin Ke's avatar
Guolin Ke committed
314
315
316
317
318

  if (read_cnt != size_of_head) {
    Log::Fatal("Binary file error: header is incorrect");
  }
  // get header
Guolin Ke's avatar
Guolin Ke committed
319
  const char* mem_ptr = buffer.data();
Guolin Ke's avatar
Guolin Ke committed
320
321
322
323
324
325
  dataset->num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_data_);
  dataset->num_features_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_features_);
  dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
326
327
  dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->label_idx_);
328
329
330
331
332
333
334
335
336
337
  dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->max_bin_);
  dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->bin_construct_sample_cnt_);
  dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->min_data_in_bin_);
  dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += sizeof(dataset->use_missing_);
  dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += sizeof(dataset->zero_as_missing_);
Guolin Ke's avatar
Guolin Ke committed
338
339
  const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
  dataset->used_feature_map_.clear();
Guolin Ke's avatar
Guolin Ke committed
340
  for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
341
342
    dataset->used_feature_map_.push_back(tmp_feature_map[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
  mem_ptr += sizeof(int) * dataset->num_total_features_;
  // num_groups
  dataset->num_groups_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_groups_);
  // real_feature_idx_
  const int* tmp_ptr_real_feature_idx_ = reinterpret_cast<const int*>(mem_ptr);
  dataset->real_feature_idx_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->real_feature_idx_.push_back(tmp_ptr_real_feature_idx_[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // feature2group
  const int* tmp_ptr_feature2group = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2group_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2group_.push_back(tmp_ptr_feature2group[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // feature2subfeature
  const int* tmp_ptr_feature2subfeature = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2subfeature_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2subfeature_.push_back(tmp_ptr_feature2subfeature[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // group_bin_boundaries
  const uint64_t* tmp_ptr_group_bin_boundaries = reinterpret_cast<const uint64_t*>(mem_ptr);
  dataset->group_bin_boundaries_.clear();
  for (int i = 0; i < dataset->num_groups_ + 1; ++i) {
    dataset->group_bin_boundaries_.push_back(tmp_ptr_group_bin_boundaries[i]);
  }
  mem_ptr += sizeof(uint64_t) * (dataset->num_groups_ + 1);

  // group_feature_start_
  const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_start_.clear();
379
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
380
381
382
383
384
385
386
387
388
389
390
391
    dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
  }
  mem_ptr += sizeof(int) * (dataset->num_groups_);

  // group_feature_cnt_
  const int* tmp_ptr_group_feature_cnt = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_cnt_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_cnt_.push_back(tmp_ptr_group_feature_cnt[i]);
  }
  mem_ptr += sizeof(int) * (dataset->num_groups_);

392
  if (!config_.monotone_constraints.empty()) {
393
    CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.monotone_constraints.size());
394
    dataset->monotone_types_.resize(dataset->num_features_);
395
    for (int i = 0; i < dataset->num_total_features_; ++i) {
396
      int inner_fidx = dataset->InnerFeatureIndex(i);
397
      if (inner_fidx >= 0) {
398
399
400
        dataset->monotone_types_[inner_fidx] = config_.monotone_constraints[i];
      }
    }
401
  } else {
402
403
404
405
406
    const int8_t* tmp_ptr_monotone_type = reinterpret_cast<const int8_t*>(mem_ptr);
    dataset->monotone_types_.clear();
    for (int i = 0; i < dataset->num_features_; ++i) {
      dataset->monotone_types_.push_back(tmp_ptr_monotone_type[i]);
    }
Guolin Ke's avatar
Guolin Ke committed
407
408
409
410
411
412
413
  }
  mem_ptr += sizeof(int8_t) * (dataset->num_features_);

  if (ArrayArgs<int8_t>::CheckAllZero(dataset->monotone_types_)) {
    dataset->monotone_types_.clear();
  }

414
  if (!config_.feature_contri.empty()) {
415
    CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.feature_contri.size());
416
    dataset->feature_penalty_.resize(dataset->num_features_);
417
    for (int i = 0; i < dataset->num_total_features_; ++i) {
418
      int inner_fidx = dataset->InnerFeatureIndex(i);
419
      if (inner_fidx >= 0) {
420
421
422
        dataset->feature_penalty_[inner_fidx] = config_.feature_contri[i];
      }
    }
423
  } else {
424
425
426
427
428
    const double* tmp_ptr_feature_penalty = reinterpret_cast<const double*>(mem_ptr);
    dataset->feature_penalty_.clear();
    for (int i = 0; i < dataset->num_features_; ++i) {
      dataset->feature_penalty_.push_back(tmp_ptr_feature_penalty[i]);
    }
Guolin Ke's avatar
Guolin Ke committed
429
430
431
432
433
434
435
  }
  mem_ptr += sizeof(double) * (dataset->num_features_);

  if (ArrayArgs<double>::CheckAll(dataset->feature_penalty_, 1)) {
    dataset->feature_penalty_.clear();
  }

Belinda Trotta's avatar
Belinda Trotta committed
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
  if (!config_.max_bin_by_feature.empty()) {
    CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.max_bin_by_feature.size());
    CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
  } else {
    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
    dataset->max_bin_by_feature_.clear();
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
    }
  }
  mem_ptr += sizeof(int32_t) * (dataset->num_total_features_);
  if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
    dataset->max_bin_by_feature_.clear();
  }

Guolin Ke's avatar
Guolin Ke committed
453
  // get feature names
Guolin Ke's avatar
Guolin Ke committed
454
  dataset->feature_names_.clear();
Guolin Ke's avatar
Guolin Ke committed
455
456
457
458
459
460
461
462
463
464
  // write feature names
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int str_len = *(reinterpret_cast<const int*>(mem_ptr));
    mem_ptr += sizeof(int);
    std::stringstream str_buf;
    for (int j = 0; j < str_len; ++j) {
      char tmp_char = *(reinterpret_cast<const char*>(mem_ptr));
      mem_ptr += sizeof(char);
      str_buf << tmp_char;
    }
Guolin Ke's avatar
Guolin Ke committed
465
    dataset->feature_names_.emplace_back(str_buf.str());
Guolin Ke's avatar
Guolin Ke committed
466
  }
467
468
469
470
471
472
473
474
475
476
477
478
479
  // get forced_bin_bounds_
  dataset->forced_bin_bounds_ = std::vector<std::vector<double>>(dataset->num_total_features_, std::vector<double>());
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int num_bounds = *(reinterpret_cast<const int*>(mem_ptr));
    mem_ptr += sizeof(int);
    dataset->forced_bin_bounds_[i] = std::vector<double>();
    const double* tmp_ptr_forced_bounds = reinterpret_cast<const double*>(mem_ptr);
    for (int j = 0; j < num_bounds; ++j) {
      double bound = tmp_ptr_forced_bounds[j];
      dataset->forced_bin_bounds_[i].push_back(bound);
    }
    mem_ptr += num_bounds * sizeof(double);
  }
Guolin Ke's avatar
Guolin Ke committed
480
481

  // read size of meta data
482
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
483

484
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
485
486
487
    Log::Fatal("Binary file error: meta data has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
488
  size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
489
490
491
492

  // re-allocate space if not enough
  if (size_of_metadata > buffer_size) {
    buffer_size = size_of_metadata;
Guolin Ke's avatar
Guolin Ke committed
493
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
494
495
  }
  //  read meta data
496
  read_cnt = reader->Read(buffer.data(), size_of_metadata);
Guolin Ke's avatar
Guolin Ke committed
497
498
499
500
501

  if (read_cnt != size_of_metadata) {
    Log::Fatal("Binary file error: meta data is incorrect");
  }
  // load meta data
Guolin Ke's avatar
Guolin Ke committed
502
  dataset->metadata_.LoadFromMemory(buffer.data());
Guolin Ke's avatar
Guolin Ke committed
503

504
505
  *num_global_data = dataset->num_data_;
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
506
  // sample local used data if need to partition
Guolin Ke's avatar
Guolin Ke committed
507
  if (num_machines > 1 && !config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
508
509
510
511
    const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
512
        if (random_.NextShort(0, num_machines) == rank) {
513
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
514
515
516
517
518
519
520
521
522
        }
      }
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = dataset->metadata_.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (qid >= num_queries) {
523
524
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
525
526
527
528
        }
        if (i >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
529
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
530
531
532
533
534
            is_query_used = true;
          }
          ++qid;
        }
        if (is_query_used) {
535
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
536
537
538
        }
      }
    }
539
    dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
Guolin Ke's avatar
Guolin Ke committed
540
  }
541
  dataset->metadata_.PartitionLabel(*used_data_indices);
Guolin Ke's avatar
Guolin Ke committed
542
  // read feature data
Guolin Ke's avatar
Guolin Ke committed
543
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
544
    // read feature size
545
546
    read_cnt = reader->Read(buffer.data(), sizeof(size_t));
    if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
547
548
      Log::Fatal("Binary file error: feature %d has the wrong size", i);
    }
Guolin Ke's avatar
Guolin Ke committed
549
    size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
550
551
552
    // re-allocate space if not enough
    if (size_of_feature > buffer_size) {
      buffer_size = size_of_feature;
Guolin Ke's avatar
Guolin Ke committed
553
      buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
554
555
    }

556
    read_cnt = reader->Read(buffer.data(), size_of_feature);
Guolin Ke's avatar
Guolin Ke committed
557
558
559
560

    if (read_cnt != size_of_feature) {
      Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
    }
Guolin Ke's avatar
Guolin Ke committed
561
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
562
563
      new FeatureGroup(buffer.data(),
                       *num_global_data,
564
                       *used_data_indices)));
Guolin Ke's avatar
Guolin Ke committed
565
  }
Guolin Ke's avatar
Guolin Ke committed
566
  dataset->feature_groups_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
567
  dataset->is_finish_load_ = true;
Guolin Ke's avatar
Guolin Ke committed
568
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
569
570
}

571

572
573
574
/*!
 * \brief Build a Dataset from column-wise sampled values (the "Costruct" spelling is part of the
 *        existing public name and is kept for compatibility).
 * \param sample_values Per-column arrays of sampled values; column i holds num_per_col[i] entries
 * \param sample_indices Per-column sample indices, forwarded unchanged to Dataset::Construct
 * \param num_col Number of columns available on this machine
 * \param num_per_col Number of sampled entries for each column
 * \param total_sample_size Number of sampled rows
 * \param num_data Number of rows of the dataset being constructed
 * \return Newly allocated Dataset (released from a unique_ptr, so the caller owns it)
 *
 * With more than one machine the feature count is synchronized to the global maximum, each machine
 * finds bins for its own contiguous slice of features, and the serialized BinMappers are all-gathered.
 */
Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
                                               int** sample_indices, int num_col, const int* num_per_col,
                                               size_t total_sample_size, data_size_t num_data) {
  int num_total_features = num_col;
  if (Network::num_machines() > 1) {
    num_total_features = Network::GlobalSyncUpByMax(num_total_features);
  }
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_total_features);
  // fill feature_names_ if not header
  if (feature_names_.empty()) {
    for (int i = 0; i < num_col; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
  if (!config_.max_bin_by_feature.empty()) {
    CHECK(static_cast<size_t>(num_col) == config_.max_bin_by_feature.size());
    CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
  }

  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);

  const data_size_t filter_cnt = static_cast<data_size_t>(
    static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
  if (Network::num_machines() == 1) {
    // if only one machine, find bin locally
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
        bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
        if (!feat_is_unconstrained) {
          Log::Fatal("The output cannot be monotone with respect to categorical features");
        }
      }
      bin_mappers[i].reset(new BinMapper());
      const int max_bin = config_.max_bin_by_feature.empty()
                          ? config_.max_bin : config_.max_bin_by_feature[i];
      bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                              max_bin, config_.min_data_in_bin, filter_cnt,
                              bin_type, config_.use_missing, config_.zero_as_missing,
                              forced_bin_bounds[i]);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    // if have multi-machines, need to find bin distributed
    // different machines will find bin for different features
    int num_machines = Network::num_machines();
    int rank = Network::rank();
    // start and len will store the process feature indices for different machines
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
    int step = (num_total_features + num_machines - 1) / num_machines;
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
      len[i] = std::min(step, num_total_features - start[i]);
      start[i + 1] = start[i] + len[i];
    }
    len[num_machines - 1] = num_total_features - start[num_machines - 1];
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < len[rank]; ++i) {
      OMP_LOOP_EX_BEGIN();
      // bin_mappers[i] is a local slot; feature_idx is the global column it describes.
      // Every per-feature lookup (config, samples, forced bounds) must use feature_idx.
      const int feature_idx = start[rank] + i;
      if (ignore_features_.count(feature_idx) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(feature_idx)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
      // columns added by GlobalSyncUpByMax have no local samples: keep the default mapper
      if (num_col <= feature_idx) {
        continue;
      }
      const int max_bin = config_.max_bin_by_feature.empty()
                          ? config_.max_bin : config_.max_bin_by_feature[feature_idx];
      bin_mappers[i]->FindBin(sample_values[feature_idx], num_per_col[feature_idx],
                              total_sample_size, max_bin, config_.min_data_in_bin,
                              filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing,
                              forced_bin_bounds[feature_idx]);
      OMP_LOOP_EX_END();
    }
    // re-raise any exception captured inside the parallel loop instead of silently
    // serializing a partially built BinMapper below
    OMP_THROW_EX();
    comm_size_t self_buf_size = 0;
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
    }
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
      // free
      bin_mappers[i].reset(nullptr);
    }
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
    }
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
    // gather global feature bin mappers
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
    // restore features bins from buffer
    for (int i = 0; i < num_total_features; ++i) {
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      bin_mappers[i].reset(new BinMapper());
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
  dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
  dataset->set_feature_names(feature_names_);
  return dataset.release();
}
Guolin Ke's avatar
Guolin Ke committed
722
723
724
725
726
727


// ---- private functions ----

void DatasetLoader::CheckDataset(const Dataset* dataset) {
  if (dataset->num_data_ <= 0) {
Guolin Ke's avatar
Guolin Ke committed
728
    Log::Fatal("Data file %s is empty", dataset->data_filename_.c_str());
Guolin Ke's avatar
Guolin Ke committed
729
  }
730
731
  if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
    Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
732
               static_cast<int>(dataset->feature_names_.size()));
733
  }
Guolin Ke's avatar
Guolin Ke committed
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
  bool is_feature_order_by_group = true;
  int last_group = -1;
  int last_sub_feature = -1;
  // if features are ordered, not need to use hist_buf
  for (int i = 0; i < dataset->num_features_; ++i) {
    int group = dataset->feature2group_[i];
    int sub_feature = dataset->feature2subfeature_[i];
    if (group < last_group) {
      is_feature_order_by_group = false;
    } else if (group == last_group) {
      if (sub_feature <= last_sub_feature) {
        is_feature_order_by_group = false;
        break;
      }
    }
    last_group = group;
    last_sub_feature = sub_feature;
  }
  if (!is_feature_order_by_group) {
753
    Log::Fatal("Features in dataset should be ordered by group");
Guolin Ke's avatar
Guolin Ke committed
754
  }
Guolin Ke's avatar
Guolin Ke committed
755
756
757
}

std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
758
759
                                                             int rank, int num_machines, int* num_global_data,
                                                             std::vector<data_size_t>* used_data_indices) {
760
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
761
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
762
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
763
764
765
766
767
768
769
770
771
    // read all lines
    *num_global_data = text_reader.ReadAllLines();
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();

    if (query_boundaries == nullptr) {
      // if not contain query data, minimal sample unit is one record
      *num_global_data = text_reader.ReadAndFilterLines([this, rank, num_machines](data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
772
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
773
774
775
776
777
778
779
780
781
782
783
784
785
786
          return true;
        } else {
          return false;
        }
      }, used_data_indices);
    } else {
      // if contain query data, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.ReadAndFilterLines(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
787
788
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
789
790
791
792
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
793
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
794
795
796
797
798
799
800
801
802
803
804
805
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices);
    }
  }
  return std::move(text_reader.Lines());
}

std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
Guolin Ke's avatar
Guolin Ke committed
806
  int sample_cnt = config_.bin_construct_sample_cnt;
807
808
  if (static_cast<size_t>(sample_cnt) > data.size()) {
    sample_cnt = static_cast<int>(data.size());
809
  }
810
  auto sample_indices = random_.Sample(static_cast<int>(data.size()), sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
811
  std::vector<std::string> out(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
812
813
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    const size_t idx = sample_indices[i];
Guolin Ke's avatar
Guolin Ke committed
814
    out[i] = data[idx];
Guolin Ke's avatar
Guolin Ke committed
815
816
817
818
  }
  return out;
}

819
820
821
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata,
                                                               int rank, int num_machines, int* num_global_data,
                                                               std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
822
  const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
823
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
824
  std::vector<std::string> out_data;
Guolin Ke's avatar
Guolin Ke committed
825
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
826
    *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(&random_, sample_cnt, &out_data));
Guolin Ke's avatar
Guolin Ke committed
827
828
829
830
831
832
833
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      *num_global_data = text_reader.SampleAndFilterFromFile([this, rank, num_machines]
      (data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
834
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
835
836
837
838
          return true;
        } else {
          return false;
        }
Guolin Ke's avatar
Guolin Ke committed
839
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
840
841
842
843
844
845
846
847
848
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.SampleAndFilterFromFile(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
849
850
          Log::Fatal("Query id exceeds the range of the query file, "
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
851
852
853
854
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
855
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
856
857
858
859
860
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
Guolin Ke's avatar
Guolin Ke committed
861
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
862
863
864
865
866
    }
  }
  return out_data;
}

867
868
869
void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
                                                    const std::vector<std::string>& sample_data,
                                                    const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
870
  std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
871
  std::vector<std::vector<int>> sample_indices;
Guolin Ke's avatar
Guolin Ke committed
872
873
  std::vector<std::pair<int, double>> oneline_features;
  double label;
Guolin Ke's avatar
Guolin Ke committed
874
  for (int i = 0; i < static_cast<int>(sample_data.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
875
876
877
878
    oneline_features.clear();
    // parse features
    parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
    for (std::pair<int, double>& inner_data : oneline_features) {
879
      if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
880
881
        sample_values.resize(inner_data.first + 1);
        sample_indices.resize(inner_data.first + 1);
882
      }
Guolin Ke's avatar
Guolin Ke committed
883
      if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
884
885
        sample_values[inner_data.first].emplace_back(inner_data.second);
        sample_indices[inner_data.first].emplace_back(i);
Guolin Ke's avatar
Guolin Ke committed
886
887
888
889
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
890
  dataset->feature_groups_.clear();
891
892
893
894
895
896
  dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
  if (num_machines > 1) {
    dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_);
  }
  if (!feature_names_.empty()) {
    CHECK(dataset->num_total_features_ == static_cast<int>(feature_names_.size()));
897
  }
Guolin Ke's avatar
Guolin Ke committed
898

Belinda Trotta's avatar
Belinda Trotta committed
899
900
901
902
903
  if (!config_.max_bin_by_feature.empty()) {
    CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.max_bin_by_feature.size());
    CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
  }

904
905
  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
906
907
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path,
                                                                                    dataset->num_total_features_,
908
909
                                                                                    categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
910
911
912
913
914
915
  // check the range of label_idx, weight_idx and group_idx
  CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
  CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
  CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);

  // fill feature_names_ if not header
Guolin Ke's avatar
Guolin Ke committed
916
  if (feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
917
918
919
920
921
922
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
923
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
924
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
925
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
926
    static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
Guolin Ke's avatar
Guolin Ke committed
927
928
929
  // start find bins
  if (num_machines == 1) {
    // if only one machine, find bin locally
930
    OMP_INIT_EX();
931
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
932
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
933
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
934
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
935
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
936
937
        continue;
      }
938
939
940
941
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
Guolin Ke's avatar
Guolin Ke committed
942
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
943
944
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
945
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
946
947
                                filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
948
949
      } else {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
950
951
                                sample_data.size(), config_.max_bin_by_feature[i],
                                config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing,
952
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
953
      }
954
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
955
    }
956
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
957
958
  } else {
    // start and len will store the process feature indices for different machines
959
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
Guolin Ke's avatar
Guolin Ke committed
960
961
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
962
    int step = (dataset->num_total_features_ + num_machines - 1) / num_machines;
Guolin Ke's avatar
Guolin Ke committed
963
964
965
966
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
967
      len[i] = std::min(step, dataset->num_total_features_ - start[i]);
Guolin Ke's avatar
Guolin Ke committed
968
969
      start[i + 1] = start[i] + len[i];
    }
970
    len[num_machines - 1] = dataset->num_total_features_ - start[num_machines - 1];
971
    OMP_INIT_EX();
972
    #pragma omp parallel for schedule(guided)
973
    for (int i = 0; i < len[rank]; ++i) {
974
      OMP_LOOP_EX_BEGIN();
975
976
977
978
979
980
981
982
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
Nikita Titov's avatar
Nikita Titov committed
983
      if (static_cast<int>(sample_values.size()) <= start[rank] + i) {
984
985
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
986
      if (config_.max_bin_by_feature.empty()) {
987
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
988
                                static_cast<int>(sample_values[start[rank] + i].size()),
989
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
990
                                filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing,
991
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
992
      } else {
993
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
994
                                static_cast<int>(sample_values[start[rank] + i].size()),
995
996
                                sample_data.size(), config_.max_bin_by_feature[i],
                                config_.min_data_in_bin, filter_cnt, bin_type,
997
                                config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
998
      }
999
      OMP_LOOP_EX_END();
1000
    }
1001
    OMP_THROW_EX();
1002
    comm_size_t self_buf_size = 0;
Guolin Ke's avatar
Guolin Ke committed
1003
    for (int i = 0; i < len[rank]; ++i) {
1004
1005
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
Guolin Ke's avatar
Guolin Ke committed
1006
      }
1007
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
1008
    }
1009
1010
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1011
    for (int i = 0; i < len[rank]; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1012
1013
1014
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
1015
1016
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
1017
1018
      // free
      bin_mappers[i].reset(nullptr);
Guolin Ke's avatar
Guolin Ke committed
1019
    }
1020
1021
1022
1023
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
Guolin Ke's avatar
Guolin Ke committed
1024
    }
1025
1026
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
Guolin Ke's avatar
Guolin Ke committed
1027
    // gather global feature bin mappers
1028
1029
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1030
    // restore features bins from buffer
1031
    for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1032
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1033
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1034
1035
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
1036
      bin_mappers[i].reset(new BinMapper());
1037
1038
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
Guolin Ke's avatar
Guolin Ke committed
1039
1040
    }
  }
1041
  dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Guolin Ke's avatar
Guolin Ke committed
1042
                     Common::Vector2Ptr<double>(&sample_values).data(),
1043
                     Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
Guolin Ke's avatar
Guolin Ke committed
1044
1045
1046
}

/*! \brief Extract local features from memory */
Guolin Ke's avatar
Guolin Ke committed
1047
void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
1048
1049
  std::vector<std::pair<int, double>> oneline_features;
  double tmp_label = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
1050
  auto& ref_text_data = *text_data;
Guolin Ke's avatar
Guolin Ke committed
1051
  if (predict_fun_ == nullptr) {
1052
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1053
    // if doesn't need to prediction with initial model
1054
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1055
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1056
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1057
1058
1059
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1060
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1061
      // set label
1062
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1063
      // free processed line:
Guolin Ke's avatar
Guolin Ke committed
1064
      ref_text_data[i].clear();
Guolin Ke's avatar
Guolin Ke committed
1065
1066
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
1067
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1068
1069
      // push data
      for (auto& inner_data : oneline_features) {
1070
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1071
1072
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1073
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1074
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1075
1076
1077
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1078
1079
        } else {
          if (inner_data.first == weight_idx_) {
1080
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1081
1082
1083
1084
1085
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1086
      dataset->FinishOneRow(tid, i, is_feature_added);
1087
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1088
    }
1089
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1090
  } else {
1091
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1092
    // if need to prediction with initial model
1093
    std::vector<double> init_score(dataset->num_data_ * num_class_);
1094
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1095
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1096
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1097
1098
1099
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1100
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1101
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1102
1103
      std::vector<double> oneline_init_score(num_class_);
      predict_fun_(oneline_features, oneline_init_score.data());
1104
      for (int k = 0; k < num_class_; ++k) {
1105
        init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1106
1107
      }
      // set label
1108
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1109
1110
1111
1112
1113
      // free processed line:
      text_data[i].clear();
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
Guolin Ke's avatar
Guolin Ke committed
1114
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1115
      for (auto& inner_data : oneline_features) {
1116
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1117
1118
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1119
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1120
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1121
1122
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
1123
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1124
1125
        } else {
          if (inner_data.first == weight_idx_) {
1126
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1127
1128
1129
1130
1131
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1132
      dataset->FinishOneRow(tid, i, is_feature_added);
1133
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1134
    }
1135
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1136
    // metadata_ will manage space of init_score
1137
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1138
  }
Guolin Ke's avatar
Guolin Ke committed
1139
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1140
  // text data can be free after loaded feature values
Guolin Ke's avatar
Guolin Ke committed
1141
  text_data->clear();
Guolin Ke's avatar
Guolin Ke committed
1142
1143
1144
}

/*! \brief Extract local features from file */
1145
1146
void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser,
                                            const std::vector<data_size_t>& used_data_indices, Dataset* dataset) {
1147
  std::vector<double> init_score;
Guolin Ke's avatar
Guolin Ke committed
1148
  if (predict_fun_ != nullptr) {
1149
    init_score = std::vector<double>(dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1150
1151
1152
1153
1154
1155
  }
  std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
    [this, &init_score, &parser, &dataset]
  (data_size_t start_idx, const std::vector<std::string>& lines) {
    std::vector<std::pair<int, double>> oneline_features;
    double tmp_label = 0.0f;
1156
    OMP_INIT_EX();
1157
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1158
    for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
1159
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1160
1161
1162
1163
1164
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1165
      if (!init_score.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1166
1167
        std::vector<double> oneline_init_score(num_class_);
        predict_fun_(oneline_features, oneline_init_score.data());
1168
        for (int k = 0; k < num_class_; ++k) {
1169
          init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1170
1171
1172
        }
      }
      // set label
1173
      dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1174
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1175
1176
      // push data
      for (auto& inner_data : oneline_features) {
1177
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1178
1179
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1180
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1181
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1182
1183
1184
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1185
1186
        } else {
          if (inner_data.first == weight_idx_) {
1187
            dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1188
1189
1190
1191
1192
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1193
      dataset->FinishOneRow(tid, i, is_feature_added);
1194
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1195
    }
1196
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1197
  };
1198
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
1199
  if (!used_data_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1200
1201
1202
1203
1204
1205
1206
1207
    // only need part of data
    text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
  } else {
    // need full data
    text_reader.ReadAllAndProcessParallel(process_fun);
  }

  // metadata_ will manage space of init_score
Guolin Ke's avatar
Guolin Ke committed
1208
  if (!init_score.empty()) {
1209
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1210
  }
Guolin Ke's avatar
Guolin Ke committed
1211
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1212
1213
1214
}

/*! \brief Check can load from binary file */
1215
std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
1216
1217
1218
  std::string bin_filename(filename);
  bin_filename.append(".bin");

1219
  auto reader = VirtualFileReader::Make(bin_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
1220

1221
  if (!reader->Init()) {
1222
    bin_filename = std::string(filename);
1223
1224
    reader = VirtualFileReader::Make(bin_filename.c_str());
    if (!reader->Init()) {
1225
      Log::Fatal("Cannot open data file %s", bin_filename.c_str());
1226
    }
1227
  }
1228
1229
1230
1231
1232

  size_t buffer_size = 256;
  auto buffer = std::vector<char>(buffer_size);
  // read size of token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
1233
  size_t read_cnt = reader->Read(buffer.data(), size_of_token);
1234
1235
  if (read_cnt == size_of_token
      && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
1236
    return bin_filename;
Guolin Ke's avatar
Guolin Ke committed
1237
  } else {
1238
    return std::string();
Guolin Ke's avatar
Guolin Ke committed
1239
1240
1241
  }
}

1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279


std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
                                                              const std::unordered_set<int>& categorical_features) {
  std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
  if (forced_bins_path != "") {
    std::ifstream forced_bins_stream(forced_bins_path.c_str());
    if (forced_bins_stream.fail()) {
      Log::Warning("Could not open %s. Will ignore.", forced_bins_path.c_str());
    } else {
      std::stringstream buffer;
      buffer << forced_bins_stream.rdbuf();
      std::string err;
      Json forced_bins_json = Json::parse(buffer.str(), err);
      CHECK(forced_bins_json.is_array());
      std::vector<Json> forced_bins_arr = forced_bins_json.array_items();
      for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
        int feature_num = forced_bins_arr[i]["feature"].int_value();
        CHECK(feature_num < num_total_features);
        if (categorical_features.count(feature_num)) {
          Log::Warning("Feature %d is categorical. Will ignore forced bins for this  feature.", feature_num);
        } else {
          std::vector<Json> bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items();
          for (size_t j = 0; j < bounds_arr.size(); ++j) {
            forced_bins[feature_num].push_back(bounds_arr[j].number_value());
          }
        }
      }
      // remove duplicates
      for (int i = 0; i < num_total_features; ++i) {
        auto new_end = std::unique(forced_bins[i].begin(), forced_bins[i].end());
        forced_bins[i].erase(new_end, forced_bins[i].end());
      }
    }
  }
  return forced_bins;
}

1280
}  // namespace LightGBM