dataset_loader.cpp 52 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/dataset_loader.h>

#include <LightGBM/network.h>
#include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/json11.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>

#include <cstring>
#include <fstream>

Guolin Ke's avatar
Guolin Ke committed
15
16
namespace LightGBM {

17
18
using json11::Json;

Guolin Ke's avatar
Guolin Ke committed
19
20
/*!
 * \brief Create a loader bound to a configuration and resolve the special columns.
 * \param io_config Configuration copied into config_; its data_random_seed seeds random_
 * \param predict_fun Prediction callback stored in predict_fun_
 * \param num_class Number of classes stored in num_class_
 * \param filename Data file handed to SetHeader(); may be nullptr
 */
DatasetLoader::DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename)
  : config_(io_config),
    random_(config_.data_random_seed),
    predict_fun_(predict_fun),
    num_class_(num_class) {
  // Defaults for "nothing configured". They have to be in place before
  // SetHeader() runs, because SetHeader() only overwrites the ones that the
  // configuration names explicitly.
  group_idx_ = NO_SPECIFIC;
  weight_idx_ = NO_SPECIFIC;
  label_idx_ = 0;
  SetHeader(filename);
}

/*! \brief Destructor; performs no explicit cleanup of its own. */
DatasetLoader::~DatasetLoader() = default;

Guolin Ke's avatar
Guolin Ke committed
30
void DatasetLoader::SetHeader(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
31
  std::unordered_map<std::string, int> name2idx;
Guolin Ke's avatar
Guolin Ke committed
32
33
  std::string name_prefix("name:");
  if (filename != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
34
    TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
35

Guolin Ke's avatar
Guolin Ke committed
36
    // get column names
Guolin Ke's avatar
Guolin Ke committed
37
    if (config_.header) {
Guolin Ke's avatar
Guolin Ke committed
38
      std::string first_line = text_reader.first_line();
39
      feature_names_ = Common::Split(first_line.c_str(), "\t,");
Guolin Ke's avatar
Guolin Ke committed
40
41
    }

Guolin Ke's avatar
Guolin Ke committed
42
    // load label idx first
Guolin Ke's avatar
Guolin Ke committed
43
44
45
    if (config_.label_column.size() > 0) {
      if (Common::StartsWith(config_.label_column, name_prefix)) {
        std::string name = config_.label_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
46
47
48
49
50
51
52
53
54
55
        label_idx_ = -1;
        for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
          if (name == feature_names_[i]) {
            label_idx_ = i;
            break;
          }
        }
        if (label_idx_ >= 0) {
          Log::Info("Using column %s as label", name.c_str());
        } else {
56
57
          Log::Fatal("Could not find label column %s in data file \n"
                     "or data file doesn't contain header", name.c_str());
Guolin Ke's avatar
Guolin Ke committed
58
        }
Guolin Ke's avatar
Guolin Ke committed
59
      } else {
Guolin Ke's avatar
Guolin Ke committed
60
        if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
61
62
63
          Log::Fatal("label_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
64
65
        }
        Log::Info("Using column number %d as label", label_idx_);
Guolin Ke's avatar
Guolin Ke committed
66
67
      }
    }
Guolin Ke's avatar
Guolin Ke committed
68

Guolin Ke's avatar
Guolin Ke committed
69
    if (!feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
70
71
72
73
      // erase label column name
      feature_names_.erase(feature_names_.begin() + label_idx_);
      for (size_t i = 0; i < feature_names_.size(); ++i) {
        name2idx[feature_names_[i]] = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
74
      }
Guolin Ke's avatar
Guolin Ke committed
75
76
77
    }

    // load ignore columns
Guolin Ke's avatar
Guolin Ke committed
78
79
80
    if (config_.ignore_column.size() > 0) {
      if (Common::StartsWith(config_.ignore_column, name_prefix)) {
        std::string names = config_.ignore_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
81
82
83
84
85
86
87
88
89
        for (auto name : Common::Split(names.c_str(), ',')) {
          if (name2idx.count(name) > 0) {
            int tmp = name2idx[name];
            ignore_features_.emplace(tmp);
          } else {
            Log::Fatal("Could not find ignore column %s in data file", name.c_str());
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
90
        for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
Guolin Ke's avatar
Guolin Ke committed
91
92
          int tmp = 0;
          if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
93
94
95
            Log::Fatal("ignore_column is not a number,\n"
                       "if you want to use a column name,\n"
                       "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
96
97
          }
          ignore_features_.emplace(tmp);
Guolin Ke's avatar
Guolin Ke committed
98
99
100
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
101
    // load weight idx
Guolin Ke's avatar
Guolin Ke committed
102
103
104
    if (config_.weight_column.size() > 0) {
      if (Common::StartsWith(config_.weight_column, name_prefix)) {
        std::string name = config_.weight_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
105
106
107
108
109
110
        if (name2idx.count(name) > 0) {
          weight_idx_ = name2idx[name];
          Log::Info("Using column %s as weight", name.c_str());
        } else {
          Log::Fatal("Could not find weight column %s in data file", name.c_str());
        }
Guolin Ke's avatar
Guolin Ke committed
111
      } else {
Guolin Ke's avatar
Guolin Ke committed
112
        if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
113
114
115
          Log::Fatal("weight_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
116
117
        }
        Log::Info("Using column number %d as weight", weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
118
      }
Guolin Ke's avatar
Guolin Ke committed
119
      ignore_features_.emplace(weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
120
    }
Guolin Ke's avatar
Guolin Ke committed
121
    // load group idx
Guolin Ke's avatar
Guolin Ke committed
122
123
124
    if (config_.group_column.size() > 0) {
      if (Common::StartsWith(config_.group_column, name_prefix)) {
        std::string name = config_.group_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
125
126
127
128
129
130
131
        if (name2idx.count(name) > 0) {
          group_idx_ = name2idx[name];
          Log::Info("Using column %s as group/query id", name.c_str());
        } else {
          Log::Fatal("Could not find group/query column %s in data file", name.c_str());
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
132
        if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
133
134
135
          Log::Fatal("group_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
136
137
138
139
        }
        Log::Info("Using column number %d as group/query id", group_idx_);
      }
      ignore_features_.emplace(group_idx_);
Guolin Ke's avatar
Guolin Ke committed
140
141
    }
  }
Guolin Ke's avatar
Guolin Ke committed
142
143
144
  if (config_.categorical_feature.size() > 0) {
    if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
      std::string names = config_.categorical_feature.substr(name_prefix.size());
145
146
147
148
149
      for (auto name : Common::Split(names.c_str(), ',')) {
        if (name2idx.count(name) > 0) {
          int tmp = name2idx[name];
          categorical_features_.emplace(tmp);
        } else {
Guolin Ke's avatar
Guolin Ke committed
150
          Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
151
152
153
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
154
      for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
155
156
        int tmp = 0;
        if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Guolin Ke's avatar
Guolin Ke committed
157
          Log::Fatal("categorical_feature is not a number,\n"
158
159
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
160
161
162
163
164
        }
        categorical_features_.emplace(tmp);
      }
    }
  }
Guolin Ke's avatar
Guolin Ke committed
165
166
}

167
/*!
 * \brief Load a training dataset from a text file, or from its binary cache when one is usable.
 *
 * For text input the data is either read fully into memory (default) or
 * processed in two passes over the file (config_.two_round): sample rows,
 * build the bin mappers from the sample, initialise the metadata, then extract
 * the features. For binary input everything is delegated to LoadFromBinFile().
 *
 * \param filename Path of the data file
 * \param rank Rank of this machine, forwarded to the loading/sampling helpers
 * \param num_machines Total number of machines
 * \return Newly created dataset, released from its unique_ptr; the caller takes ownership
 */
Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_machines) {
  // don't support query id in data file when training in parallel
  if (num_machines > 1 && !config_.pre_partition) {
    // Column index 0 is a valid group column (SetHeader() assigns zero-based
    // indices), so the test must be ">= 0" rather than "> 0"; otherwise a
    // group column in the first position slips through this guard.
    // NOTE(review): relies on NO_SPECIFIC being negative - confirm.
    if (group_idx_ >= 0) {
      Log::Fatal("Using a query id without pre-partitioning the data file is not supported for parallel training.\n"
                 "Please use an additional query file or pre-partition the data");
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto bin_filename = CheckCanLoadFromBin(filename);
  if (bin_filename.size() == 0) {
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data to memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // sample data
      auto sample_data = SampleTextDataFromMemory(text_data);
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      // sample data from file
      auto sample_data = SampleTextDataFromFile(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      // a non-empty index list means only a subset of the rows is kept locally
      if (used_data_indices.size() > 0) {
        dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
      } else {
        dataset->num_data_ = num_global_data;
      }
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      Log::Debug("Making second pass...");
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
  }
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  // need to check training data
  CheckDataset(dataset.get());
  return dataset.release();
}



229
/*!
 * \brief Load a dataset that reuses the feature layout of an existing training dataset.
 *
 * No bin mappers are built here; for text input the new dataset is aligned
 * with train_data through Dataset::CreateValid(). The file is always read as a
 * single machine (rank 0 of 1), and CheckDataset() is not applied to the result.
 *
 * \param filename Path of the data file
 * \param train_data Training dataset whose layout is reused
 * \return Newly created dataset, released from its unique_ptr; the caller takes ownership
 */
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  std::unique_ptr<Dataset> dataset(new Dataset());
  const auto bin_filename = CheckCanLoadFromBin(filename);
  if (bin_filename.size() > 0) {
    // a usable binary cache exists: load everything from it
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), 0, 1, &num_global_data, &used_data_indices));
  } else {
    std::unique_ptr<Parser> parser(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (!parser) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (config_.two_round) {
      // two passes over the file: count the lines first, then extract
      TextReader<data_size_t> text_reader(filename, config_.header);
      dataset->num_data_ = static_cast<data_size_t>(text_reader.CountLine());
      num_global_data = dataset->num_data_;
      // set up the metadata before aligning with the training data
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      dataset->CreateValid(train_data);
      // pull the feature values straight from the file
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    } else {
      // single pass: keep all lines in memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // set up the metadata before aligning with the training data
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      dataset->CreateValid(train_data);
      // pull the feature values from the in-memory lines
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    }
  }
  // validation data is not passed through CheckDataset(); only the metadata is checked
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  return dataset.release();
}

273
274
275
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
                                        int rank, int num_machines, int* num_global_data,
                                        std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
276
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
277
  auto reader = VirtualFileReader::Make(bin_filename);
Guolin Ke's avatar
Guolin Ke committed
278
  dataset->data_filename_ = data_filename;
279
  if (!reader->Init()) {
Guolin Ke's avatar
Guolin Ke committed
280
281
282
283
284
    Log::Fatal("Could not read binary data from %s", bin_filename);
  }

  // buffer to read binary file
  size_t buffer_size = 16 * 1024 * 1024;
Guolin Ke's avatar
Guolin Ke committed
285
  auto buffer = std::vector<char>(buffer_size);
286

287
288
  // check token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
289
290
  size_t read_cnt = reader->Read(buffer.data(), sizeof(char) * size_of_token);
  if (read_cnt != sizeof(char) * size_of_token) {
291
292
293
    Log::Fatal("Binary file error: token has the wrong size");
  }
  if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
294
    Log::Fatal("Input file is not LightGBM binary file");
295
  }
Guolin Ke's avatar
Guolin Ke committed
296
297

  // read size of header
298
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
299

300
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
301
302
303
    Log::Fatal("Binary file error: header has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
304
  size_t size_of_head = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
305
306
307
308

  // re-allocmate space if not enough
  if (size_of_head > buffer_size) {
    buffer_size = size_of_head;
Guolin Ke's avatar
Guolin Ke committed
309
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
310
311
  }
  // read header
312
  read_cnt = reader->Read(buffer.data(), size_of_head);
Guolin Ke's avatar
Guolin Ke committed
313
314
315
316
317

  if (read_cnt != size_of_head) {
    Log::Fatal("Binary file error: header is incorrect");
  }
  // get header
Guolin Ke's avatar
Guolin Ke committed
318
  const char* mem_ptr = buffer.data();
Guolin Ke's avatar
Guolin Ke committed
319
320
321
322
323
324
  dataset->num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_data_);
  dataset->num_features_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_features_);
  dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
325
326
  dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->label_idx_);
327
328
329
330
331
332
333
334
335
336
  dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->max_bin_);
  dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->bin_construct_sample_cnt_);
  dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->min_data_in_bin_);
  dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += sizeof(dataset->use_missing_);
  dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += sizeof(dataset->zero_as_missing_);
Guolin Ke's avatar
Guolin Ke committed
337
338
  const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
  dataset->used_feature_map_.clear();
Guolin Ke's avatar
Guolin Ke committed
339
  for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
340
341
    dataset->used_feature_map_.push_back(tmp_feature_map[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
  mem_ptr += sizeof(int) * dataset->num_total_features_;
  // num_groups
  dataset->num_groups_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_groups_);
  // real_feature_idx_
  const int* tmp_ptr_real_feature_idx_ = reinterpret_cast<const int*>(mem_ptr);
  dataset->real_feature_idx_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->real_feature_idx_.push_back(tmp_ptr_real_feature_idx_[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // feature2group
  const int* tmp_ptr_feature2group = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2group_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2group_.push_back(tmp_ptr_feature2group[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // feature2subfeature
  const int* tmp_ptr_feature2subfeature = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2subfeature_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2subfeature_.push_back(tmp_ptr_feature2subfeature[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // group_bin_boundaries
  const uint64_t* tmp_ptr_group_bin_boundaries = reinterpret_cast<const uint64_t*>(mem_ptr);
  dataset->group_bin_boundaries_.clear();
  for (int i = 0; i < dataset->num_groups_ + 1; ++i) {
    dataset->group_bin_boundaries_.push_back(tmp_ptr_group_bin_boundaries[i]);
  }
  mem_ptr += sizeof(uint64_t) * (dataset->num_groups_ + 1);

  // group_feature_start_
  const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_start_.clear();
378
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
379
380
381
382
383
384
385
386
387
388
389
390
    dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
  }
  mem_ptr += sizeof(int) * (dataset->num_groups_);

  // group_feature_cnt_
  const int* tmp_ptr_group_feature_cnt = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_cnt_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_cnt_.push_back(tmp_ptr_group_feature_cnt[i]);
  }
  mem_ptr += sizeof(int) * (dataset->num_groups_);

Belinda Trotta's avatar
Belinda Trotta committed
391
  if (!config_.max_bin_by_feature.empty()) {
392
393
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
394
395
396
397
398
399
400
401
402
403
404
405
406
407
    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
  } else {
    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
    dataset->max_bin_by_feature_.clear();
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
    }
  }
  mem_ptr += sizeof(int32_t) * (dataset->num_total_features_);
  if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
    dataset->max_bin_by_feature_.clear();
  }

Guolin Ke's avatar
Guolin Ke committed
408
  // get feature names
Guolin Ke's avatar
Guolin Ke committed
409
  dataset->feature_names_.clear();
Guolin Ke's avatar
Guolin Ke committed
410
411
412
413
414
415
416
417
418
419
  // write feature names
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int str_len = *(reinterpret_cast<const int*>(mem_ptr));
    mem_ptr += sizeof(int);
    std::stringstream str_buf;
    for (int j = 0; j < str_len; ++j) {
      char tmp_char = *(reinterpret_cast<const char*>(mem_ptr));
      mem_ptr += sizeof(char);
      str_buf << tmp_char;
    }
Guolin Ke's avatar
Guolin Ke committed
420
    dataset->feature_names_.emplace_back(str_buf.str());
Guolin Ke's avatar
Guolin Ke committed
421
  }
422
423
424
425
426
427
428
429
430
431
432
433
434
  // get forced_bin_bounds_
  dataset->forced_bin_bounds_ = std::vector<std::vector<double>>(dataset->num_total_features_, std::vector<double>());
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int num_bounds = *(reinterpret_cast<const int*>(mem_ptr));
    mem_ptr += sizeof(int);
    dataset->forced_bin_bounds_[i] = std::vector<double>();
    const double* tmp_ptr_forced_bounds = reinterpret_cast<const double*>(mem_ptr);
    for (int j = 0; j < num_bounds; ++j) {
      double bound = tmp_ptr_forced_bounds[j];
      dataset->forced_bin_bounds_[i].push_back(bound);
    }
    mem_ptr += num_bounds * sizeof(double);
  }
Guolin Ke's avatar
Guolin Ke committed
435
436

  // read size of meta data
437
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
438

439
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
440
441
442
    Log::Fatal("Binary file error: meta data has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
443
  size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
444
445
446
447

  // re-allocate space if not enough
  if (size_of_metadata > buffer_size) {
    buffer_size = size_of_metadata;
Guolin Ke's avatar
Guolin Ke committed
448
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
449
450
  }
  //  read meta data
451
  read_cnt = reader->Read(buffer.data(), size_of_metadata);
Guolin Ke's avatar
Guolin Ke committed
452
453
454
455
456

  if (read_cnt != size_of_metadata) {
    Log::Fatal("Binary file error: meta data is incorrect");
  }
  // load meta data
Guolin Ke's avatar
Guolin Ke committed
457
  dataset->metadata_.LoadFromMemory(buffer.data());
Guolin Ke's avatar
Guolin Ke committed
458

459
460
  *num_global_data = dataset->num_data_;
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
461
  // sample local used data if need to partition
Guolin Ke's avatar
Guolin Ke committed
462
  if (num_machines > 1 && !config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
463
464
465
466
    const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
467
        if (random_.NextShort(0, num_machines) == rank) {
468
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
469
470
471
472
473
474
475
476
477
        }
      }
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = dataset->metadata_.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (qid >= num_queries) {
478
479
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
480
481
482
483
        }
        if (i >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
484
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
485
486
487
488
489
            is_query_used = true;
          }
          ++qid;
        }
        if (is_query_used) {
490
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
491
492
493
        }
      }
    }
494
    dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
Guolin Ke's avatar
Guolin Ke committed
495
  }
496
  dataset->metadata_.PartitionLabel(*used_data_indices);
Guolin Ke's avatar
Guolin Ke committed
497
  // read feature data
Guolin Ke's avatar
Guolin Ke committed
498
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
499
    // read feature size
500
501
    read_cnt = reader->Read(buffer.data(), sizeof(size_t));
    if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
502
503
      Log::Fatal("Binary file error: feature %d has the wrong size", i);
    }
Guolin Ke's avatar
Guolin Ke committed
504
    size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
505
506
507
    // re-allocate space if not enough
    if (size_of_feature > buffer_size) {
      buffer_size = size_of_feature;
Guolin Ke's avatar
Guolin Ke committed
508
      buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
509
510
    }

511
    read_cnt = reader->Read(buffer.data(), size_of_feature);
Guolin Ke's avatar
Guolin Ke committed
512
513
514
515

    if (read_cnt != size_of_feature) {
      Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
    }
Guolin Ke's avatar
Guolin Ke committed
516
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
517
518
      new FeatureGroup(buffer.data(),
                       *num_global_data,
519
                       *used_data_indices)));
Guolin Ke's avatar
Guolin Ke committed
520
  }
Guolin Ke's avatar
Guolin Ke committed
521
  dataset->feature_groups_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
522
  dataset->is_finish_load_ = true;
Guolin Ke's avatar
Guolin Ke committed
523
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
524
525
}

526

527
528
529
Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
                                                int** sample_indices, int num_col, const int* num_per_col,
                                                size_t total_sample_size, data_size_t num_data) {
530
531
532
533
534
  int num_total_features = num_col;
  if (Network::num_machines() > 1) {
    num_total_features = Network::GlobalSyncUpByMax(num_total_features);
  }
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_total_features);
535
536
  // fill feature_names_ if not header
  if (feature_names_.empty()) {
537
    for (int i = 0; i < num_col; ++i) {
538
539
540
541
542
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
Belinda Trotta's avatar
Belinda Trotta committed
543
  if (!config_.max_bin_by_feature.empty()) {
544
545
    CHECK_EQ(static_cast<size_t>(num_col), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
546
  }
547
548
549
550
551

  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
552
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
553
    static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
554
555
556
557
558
559
560
561
562
563
564
565
566
  if (Network::num_machines() == 1) {
    // if only one machine, find bin locally
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
567
568
569
570
        bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
        if (!feat_is_unconstrained) {
            Log::Fatal("The output cannot be monotone with respect to categorical features");
        }
571
572
      }
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
573
574
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
575
                                config_.max_bin, config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter,
576
577
                                bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
578
579
      } else {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
580
                                config_.max_bin_by_feature[i], config_.min_data_in_bin,
581
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
582
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
583
      }
584
585
586
587
588
589
590
591
592
593
594
595
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    // if have multi-machines, need to find bin distributed
    // different machines will find bin for different features
    int num_machines = Network::num_machines();
    int rank = Network::rank();
    // start and len will store the process feature indices for different machines
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
596
    int step = (num_total_features + num_machines - 1) / num_machines;
597
598
599
600
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
601
      len[i] = std::min(step, num_total_features - start[i]);
602
603
      start[i + 1] = start[i] + len[i];
    }
604
    len[num_machines - 1] = num_total_features - start[num_machines - 1];
605
606
607
608
609
610
611
612
613
614
615
616
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < len[rank]; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
617
618
619
      if (num_col <= start[rank] + i) {
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
620
      if (config_.max_bin_by_feature.empty()) {
621
622
        bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
                                total_sample_size, config_.max_bin, config_.min_data_in_bin,
623
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
624
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
625
      } else {
626
627
        bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
                                total_sample_size, config_.max_bin_by_feature[start[rank] + i],
628
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
629
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
630
      }
631
632
      OMP_LOOP_EX_END();
    }
Guolin Ke's avatar
Guolin Ke committed
633
    OMP_THROW_EX();
634
    comm_size_t self_buf_size = 0;
635
    for (int i = 0; i < len[rank]; ++i) {
636
637
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
638
      }
639
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
640
    }
641
642
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
643
644
645
646
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
647
648
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
649
650
651
      // free
      bin_mappers[i].reset(nullptr);
    }
652
653
654
655
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
656
    }
657
658
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
659
    // gather global feature bin mappers
660
661
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
662
    // restore features bins from buffer
663
    for (int i = 0; i < num_total_features; ++i) {
664
665
666
667
668
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      bin_mappers[i].reset(new BinMapper());
669
670
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
671
    }
Guolin Ke's avatar
Guolin Ke committed
672
  }
Guolin Ke's avatar
Guolin Ke committed
673
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
Guolin Ke's avatar
Guolin Ke committed
674
  dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
675
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
676
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
677
}
Guolin Ke's avatar
Guolin Ke committed
678
679
680
681
682
683


// ---- private functions ----

void DatasetLoader::CheckDataset(const Dataset* dataset) {
  if (dataset->num_data_ <= 0) {
Guolin Ke's avatar
Guolin Ke committed
684
    Log::Fatal("Data file %s is empty", dataset->data_filename_.c_str());
Guolin Ke's avatar
Guolin Ke committed
685
  }
686
687
  if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
    Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
688
               static_cast<int>(dataset->feature_names_.size()));
689
  }
Guolin Ke's avatar
Guolin Ke committed
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
  bool is_feature_order_by_group = true;
  int last_group = -1;
  int last_sub_feature = -1;
  // if features are ordered, not need to use hist_buf
  for (int i = 0; i < dataset->num_features_; ++i) {
    int group = dataset->feature2group_[i];
    int sub_feature = dataset->feature2subfeature_[i];
    if (group < last_group) {
      is_feature_order_by_group = false;
    } else if (group == last_group) {
      if (sub_feature <= last_sub_feature) {
        is_feature_order_by_group = false;
        break;
      }
    }
    last_group = group;
    last_sub_feature = sub_feature;
  }
  if (!is_feature_order_by_group) {
709
    Log::Fatal("Features in dataset should be ordered by group");
Guolin Ke's avatar
Guolin Ke committed
710
  }
Guolin Ke's avatar
Guolin Ke committed
711
712
713
}

std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
714
715
                                                             int rank, int num_machines, int* num_global_data,
                                                             std::vector<data_size_t>* used_data_indices) {
716
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
717
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
718
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
719
720
721
722
723
724
725
726
727
    // read all lines
    *num_global_data = text_reader.ReadAllLines();
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();

    if (query_boundaries == nullptr) {
      // if not contain query data, minimal sample unit is one record
      *num_global_data = text_reader.ReadAndFilterLines([this, rank, num_machines](data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
728
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
729
730
731
732
733
734
735
736
737
738
739
740
741
742
          return true;
        } else {
          return false;
        }
      }, used_data_indices);
    } else {
      // if contain query data, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.ReadAndFilterLines(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
743
744
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
745
746
747
748
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
749
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
750
751
752
753
754
755
756
757
758
759
760
761
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices);
    }
  }
  return std::move(text_reader.Lines());
}

std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
Guolin Ke's avatar
Guolin Ke committed
762
  int sample_cnt = config_.bin_construct_sample_cnt;
763
764
  if (static_cast<size_t>(sample_cnt) > data.size()) {
    sample_cnt = static_cast<int>(data.size());
765
  }
766
  auto sample_indices = random_.Sample(static_cast<int>(data.size()), sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
767
  std::vector<std::string> out(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
768
769
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    const size_t idx = sample_indices[i];
Guolin Ke's avatar
Guolin Ke committed
770
    out[i] = data[idx];
Guolin Ke's avatar
Guolin Ke committed
771
772
773
774
  }
  return out;
}

775
776
777
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata,
                                                               int rank, int num_machines, int* num_global_data,
                                                               std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
778
  const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
779
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
780
  std::vector<std::string> out_data;
Guolin Ke's avatar
Guolin Ke committed
781
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
782
    *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(&random_, sample_cnt, &out_data));
Guolin Ke's avatar
Guolin Ke committed
783
784
785
786
787
788
789
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      *num_global_data = text_reader.SampleAndFilterFromFile([this, rank, num_machines]
      (data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
790
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
791
792
793
794
          return true;
        } else {
          return false;
        }
Guolin Ke's avatar
Guolin Ke committed
795
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
796
797
798
799
800
801
802
803
804
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.SampleAndFilterFromFile(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
805
806
          Log::Fatal("Query id exceeds the range of the query file, "
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
807
808
809
810
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
811
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
812
813
814
815
816
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
Guolin Ke's avatar
Guolin Ke committed
817
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
818
819
820
821
822
    }
  }
  return out_data;
}

823
824
825
void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
                                                    const std::vector<std::string>& sample_data,
                                                    const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
826
  std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
827
  std::vector<std::vector<int>> sample_indices;
Guolin Ke's avatar
Guolin Ke committed
828
829
  std::vector<std::pair<int, double>> oneline_features;
  double label;
Guolin Ke's avatar
Guolin Ke committed
830
  for (int i = 0; i < static_cast<int>(sample_data.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
831
832
833
834
    oneline_features.clear();
    // parse features
    parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
    for (std::pair<int, double>& inner_data : oneline_features) {
835
      if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
836
837
        sample_values.resize(inner_data.first + 1);
        sample_indices.resize(inner_data.first + 1);
838
      }
Guolin Ke's avatar
Guolin Ke committed
839
      if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
840
841
        sample_values[inner_data.first].emplace_back(inner_data.second);
        sample_indices[inner_data.first].emplace_back(i);
Guolin Ke's avatar
Guolin Ke committed
842
843
844
845
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
846
  dataset->feature_groups_.clear();
847
848
849
850
851
  dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
  if (num_machines > 1) {
    dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_);
  }
  if (!feature_names_.empty()) {
852
    CHECK_EQ(dataset->num_total_features_, static_cast<int>(feature_names_.size()));
853
  }
Guolin Ke's avatar
Guolin Ke committed
854

Belinda Trotta's avatar
Belinda Trotta committed
855
  if (!config_.max_bin_by_feature.empty()) {
856
857
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
858
859
  }

860
861
  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
862
863
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path,
                                                                                    dataset->num_total_features_,
864
865
                                                                                    categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
866
867
868
869
870
871
  // check the range of label_idx, weight_idx and group_idx
  CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
  CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
  CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);

  // fill feature_names_ if not header
Guolin Ke's avatar
Guolin Ke committed
872
  if (feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
873
874
875
876
877
878
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
879
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
880
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
881
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
882
    static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
Guolin Ke's avatar
Guolin Ke committed
883
884
885
  // start find bins
  if (num_machines == 1) {
    // if only one machine, find bin locally
886
    OMP_INIT_EX();
887
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
888
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
889
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
890
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
891
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
892
893
        continue;
      }
894
895
896
897
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
Guolin Ke's avatar
Guolin Ke committed
898
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
899
900
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
901
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
902
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
903
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
904
905
      } else {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
906
                                sample_data.size(), config_.max_bin_by_feature[i],
907
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
908
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
909
      }
910
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
911
    }
912
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
913
914
  } else {
    // start and len will store the process feature indices for different machines
915
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
Guolin Ke's avatar
Guolin Ke committed
916
917
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
918
    int step = (dataset->num_total_features_ + num_machines - 1) / num_machines;
Guolin Ke's avatar
Guolin Ke committed
919
920
921
922
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
923
      len[i] = std::min(step, dataset->num_total_features_ - start[i]);
Guolin Ke's avatar
Guolin Ke committed
924
925
      start[i + 1] = start[i] + len[i];
    }
926
    len[num_machines - 1] = dataset->num_total_features_ - start[num_machines - 1];
927
    OMP_INIT_EX();
928
    #pragma omp parallel for schedule(guided)
929
    for (int i = 0; i < len[rank]; ++i) {
930
      OMP_LOOP_EX_BEGIN();
931
932
933
934
935
936
937
938
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
Nikita Titov's avatar
Nikita Titov committed
939
      if (static_cast<int>(sample_values.size()) <= start[rank] + i) {
940
941
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
942
      if (config_.max_bin_by_feature.empty()) {
943
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
944
                                static_cast<int>(sample_values[start[rank] + i].size()),
945
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
946
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
947
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
948
      } else {
949
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
950
                                static_cast<int>(sample_values[start[rank] + i].size()),
951
                                sample_data.size(), config_.max_bin_by_feature[i],
952
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type,
953
                                config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
954
      }
955
      OMP_LOOP_EX_END();
956
    }
957
    OMP_THROW_EX();
958
    comm_size_t self_buf_size = 0;
Guolin Ke's avatar
Guolin Ke committed
959
    for (int i = 0; i < len[rank]; ++i) {
960
961
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
Guolin Ke's avatar
Guolin Ke committed
962
      }
963
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
964
    }
965
966
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
967
    for (int i = 0; i < len[rank]; ++i) {
Guolin Ke's avatar
Guolin Ke committed
968
969
970
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
971
972
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
973
974
      // free
      bin_mappers[i].reset(nullptr);
Guolin Ke's avatar
Guolin Ke committed
975
    }
976
977
978
979
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
Guolin Ke's avatar
Guolin Ke committed
980
    }
981
982
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
Guolin Ke's avatar
Guolin Ke committed
983
    // gather global feature bin mappers
984
985
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
986
    // restore features bins from buffer
987
    for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
988
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
989
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
990
991
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
992
      bin_mappers[i].reset(new BinMapper());
993
994
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
Guolin Ke's avatar
Guolin Ke committed
995
996
    }
  }
997
  dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Guolin Ke's avatar
Guolin Ke committed
998
                     Common::Vector2Ptr<double>(&sample_values).data(),
999
                     Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
Guolin Ke's avatar
Guolin Ke committed
1000
1001
1002
}

/*! \brief Extract local features from memory */
Guolin Ke's avatar
Guolin Ke committed
1003
void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
1004
1005
  std::vector<std::pair<int, double>> oneline_features;
  double tmp_label = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
1006
  auto& ref_text_data = *text_data;
Guolin Ke's avatar
Guolin Ke committed
1007
  if (predict_fun_ == nullptr) {
1008
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1009
    // if doesn't need to prediction with initial model
1010
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1011
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1012
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1013
1014
1015
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1016
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1017
      // set label
1018
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1019
      // free processed line:
Guolin Ke's avatar
Guolin Ke committed
1020
      ref_text_data[i].clear();
Guolin Ke's avatar
Guolin Ke committed
1021
1022
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
1023
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1024
1025
      // push data
      for (auto& inner_data : oneline_features) {
1026
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1027
1028
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1029
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1030
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1031
1032
1033
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1034
1035
        } else {
          if (inner_data.first == weight_idx_) {
1036
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1037
1038
1039
1040
1041
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1042
      dataset->FinishOneRow(tid, i, is_feature_added);
1043
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1044
    }
1045
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1046
  } else {
1047
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1048
    // if need to prediction with initial model
1049
    std::vector<double> init_score(dataset->num_data_ * num_class_);
1050
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1051
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1052
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1053
1054
1055
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1056
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1057
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1058
1059
      std::vector<double> oneline_init_score(num_class_);
      predict_fun_(oneline_features, oneline_init_score.data());
1060
      for (int k = 0; k < num_class_; ++k) {
1061
        init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1062
1063
      }
      // set label
1064
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1065
1066
1067
1068
1069
      // free processed line:
      text_data[i].clear();
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
Guolin Ke's avatar
Guolin Ke committed
1070
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1071
      for (auto& inner_data : oneline_features) {
1072
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1073
1074
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1075
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1076
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1077
1078
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
1079
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1080
1081
        } else {
          if (inner_data.first == weight_idx_) {
1082
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1083
1084
1085
1086
1087
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1088
      dataset->FinishOneRow(tid, i, is_feature_added);
1089
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1090
    }
1091
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1092
    // metadata_ will manage space of init_score
1093
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1094
  }
Guolin Ke's avatar
Guolin Ke committed
1095
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1096
  // text data can be free after loaded feature values
Guolin Ke's avatar
Guolin Ke committed
1097
  text_data->clear();
Guolin Ke's avatar
Guolin Ke committed
1098
1099
1100
}

/*! \brief Extract local features from file */
1101
1102
void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser,
                                            const std::vector<data_size_t>& used_data_indices, Dataset* dataset) {
1103
  std::vector<double> init_score;
Guolin Ke's avatar
Guolin Ke committed
1104
  if (predict_fun_ != nullptr) {
1105
    init_score = std::vector<double>(dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1106
1107
1108
1109
1110
1111
  }
  std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
    [this, &init_score, &parser, &dataset]
  (data_size_t start_idx, const std::vector<std::string>& lines) {
    std::vector<std::pair<int, double>> oneline_features;
    double tmp_label = 0.0f;
1112
    OMP_INIT_EX();
1113
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1114
    for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
1115
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1116
1117
1118
1119
1120
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1121
      if (!init_score.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1122
1123
        std::vector<double> oneline_init_score(num_class_);
        predict_fun_(oneline_features, oneline_init_score.data());
1124
        for (int k = 0; k < num_class_; ++k) {
1125
          init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1126
1127
1128
        }
      }
      // set label
1129
      dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1130
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1131
1132
      // push data
      for (auto& inner_data : oneline_features) {
1133
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1134
1135
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1136
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1137
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1138
1139
1140
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1141
1142
        } else {
          if (inner_data.first == weight_idx_) {
1143
            dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1144
1145
1146
1147
1148
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1149
      dataset->FinishOneRow(tid, i, is_feature_added);
1150
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1151
    }
1152
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1153
  };
1154
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
1155
  if (!used_data_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1156
1157
1158
1159
1160
1161
1162
1163
    // only need part of data
    text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
  } else {
    // need full data
    text_reader.ReadAllAndProcessParallel(process_fun);
  }

  // metadata_ will manage space of init_score
Guolin Ke's avatar
Guolin Ke committed
1164
  if (!init_score.empty()) {
1165
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1166
  }
Guolin Ke's avatar
Guolin Ke committed
1167
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1168
1169
1170
}

/*! \brief Check can load from binary file */
1171
std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
1172
1173
1174
  std::string bin_filename(filename);
  bin_filename.append(".bin");

1175
  auto reader = VirtualFileReader::Make(bin_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
1176

1177
  if (!reader->Init()) {
1178
    bin_filename = std::string(filename);
1179
1180
    reader = VirtualFileReader::Make(bin_filename.c_str());
    if (!reader->Init()) {
1181
      Log::Fatal("Cannot open data file %s", bin_filename.c_str());
1182
    }
1183
  }
1184
1185
1186
1187
1188

  size_t buffer_size = 256;
  auto buffer = std::vector<char>(buffer_size);
  // read size of token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
1189
  size_t read_cnt = reader->Read(buffer.data(), size_of_token);
1190
1191
  if (read_cnt == size_of_token
      && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
1192
    return bin_filename;
Guolin Ke's avatar
Guolin Ke committed
1193
  } else {
1194
    return std::string();
Guolin Ke's avatar
Guolin Ke committed
1195
1196
1197
  }
}

1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210


std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
                                                              const std::unordered_set<int>& categorical_features) {
  std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
  if (forced_bins_path != "") {
    std::ifstream forced_bins_stream(forced_bins_path.c_str());
    if (forced_bins_stream.fail()) {
      Log::Warning("Could not open %s. Will ignore.", forced_bins_path.c_str());
    } else {
      std::stringstream buffer;
      buffer << forced_bins_stream.rdbuf();
      std::string err;
Guolin Ke's avatar
Guolin Ke committed
1211
      Json forced_bins_json = Json::parse(buffer.str(), &err);
1212
1213
1214
1215
      CHECK(forced_bins_json.is_array());
      std::vector<Json> forced_bins_arr = forced_bins_json.array_items();
      for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
        int feature_num = forced_bins_arr[i]["feature"].int_value();
Nikita Titov's avatar
Nikita Titov committed
1216
        CHECK_LT(feature_num, num_total_features);
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
        if (categorical_features.count(feature_num)) {
          Log::Warning("Feature %d is categorical. Will ignore forced bins for this  feature.", feature_num);
        } else {
          std::vector<Json> bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items();
          for (size_t j = 0; j < bounds_arr.size(); ++j) {
            forced_bins[feature_num].push_back(bounds_arr[j].number_value());
          }
        }
      }
      // remove duplicates
      for (int i = 0; i < num_total_features; ++i) {
        auto new_end = std::unique(forced_bins[i].begin(), forced_bins[i].end());
        forced_bins[i].erase(new_end, forced_bins[i].end());
      }
    }
  }
  return forced_bins;
}

1236
}  // namespace LightGBM