dataset_loader.cpp 52 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/dataset_loader.h>

Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/network.h>
8
9
10
#include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
11

12
13
#include <fstream>

14
15
#include <LightGBM/json11.hpp>

Guolin Ke's avatar
Guolin Ke committed
16
17
namespace LightGBM {

18
19
using json11::Json;

Guolin Ke's avatar
Guolin Ke committed
20
21
/*!
 * \brief Create a loader and resolve the special columns for the given data file.
 * \param io_config Configuration used to initialize config_
 * \param predict_fun Prediction callback stored in predict_fun_
 * \param num_class Number of classes stored in num_class_
 * \param filename Data file handed to SetHeader; SetHeader accepts nullptr
 */
DatasetLoader::DatasetLoader(const Config& io_config,
                             const PredictFunction& predict_fun,
                             int num_class,
                             const char* filename)
    : config_(io_config),
      random_(config_.data_random_seed),
      predict_fun_(predict_fun),
      num_class_(num_class) {
  // defaults that stay in effect unless SetHeader finds an explicit column option
  group_idx_ = NO_SPECIFIC;
  weight_idx_ = NO_SPECIFIC;
  label_idx_ = 0;
  // parse header / column options; may overwrite the defaults above
  SetHeader(filename);
}

// No explicit cleanup is performed by the destructor.
DatasetLoader::~DatasetLoader() = default;

Guolin Ke's avatar
Guolin Ke committed
31
void DatasetLoader::SetHeader(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
32
  std::unordered_map<std::string, int> name2idx;
Guolin Ke's avatar
Guolin Ke committed
33
34
  std::string name_prefix("name:");
  if (filename != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
35
    TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
36

Guolin Ke's avatar
Guolin Ke committed
37
    // get column names
Guolin Ke's avatar
Guolin Ke committed
38
    if (config_.header) {
Guolin Ke's avatar
Guolin Ke committed
39
      std::string first_line = text_reader.first_line();
40
      feature_names_ = Common::Split(first_line.c_str(), "\t,");
Guolin Ke's avatar
Guolin Ke committed
41
42
    }

Guolin Ke's avatar
Guolin Ke committed
43
    // load label idx first
Guolin Ke's avatar
Guolin Ke committed
44
45
46
    if (config_.label_column.size() > 0) {
      if (Common::StartsWith(config_.label_column, name_prefix)) {
        std::string name = config_.label_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
47
48
49
50
51
52
53
54
55
56
        label_idx_ = -1;
        for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
          if (name == feature_names_[i]) {
            label_idx_ = i;
            break;
          }
        }
        if (label_idx_ >= 0) {
          Log::Info("Using column %s as label", name.c_str());
        } else {
57
58
          Log::Fatal("Could not find label column %s in data file \n"
                     "or data file doesn't contain header", name.c_str());
Guolin Ke's avatar
Guolin Ke committed
59
        }
Guolin Ke's avatar
Guolin Ke committed
60
      } else {
Guolin Ke's avatar
Guolin Ke committed
61
        if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
62
63
64
          Log::Fatal("label_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
65
66
        }
        Log::Info("Using column number %d as label", label_idx_);
Guolin Ke's avatar
Guolin Ke committed
67
68
      }
    }
Guolin Ke's avatar
Guolin Ke committed
69

Guolin Ke's avatar
Guolin Ke committed
70
    if (!feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
71
72
73
74
      // erase label column name
      feature_names_.erase(feature_names_.begin() + label_idx_);
      for (size_t i = 0; i < feature_names_.size(); ++i) {
        name2idx[feature_names_[i]] = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
75
      }
Guolin Ke's avatar
Guolin Ke committed
76
77
78
    }

    // load ignore columns
Guolin Ke's avatar
Guolin Ke committed
79
80
81
    if (config_.ignore_column.size() > 0) {
      if (Common::StartsWith(config_.ignore_column, name_prefix)) {
        std::string names = config_.ignore_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
82
83
84
85
86
87
88
89
90
        for (auto name : Common::Split(names.c_str(), ',')) {
          if (name2idx.count(name) > 0) {
            int tmp = name2idx[name];
            ignore_features_.emplace(tmp);
          } else {
            Log::Fatal("Could not find ignore column %s in data file", name.c_str());
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
91
        for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
Guolin Ke's avatar
Guolin Ke committed
92
93
          int tmp = 0;
          if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
94
95
96
            Log::Fatal("ignore_column is not a number,\n"
                       "if you want to use a column name,\n"
                       "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
97
98
          }
          ignore_features_.emplace(tmp);
Guolin Ke's avatar
Guolin Ke committed
99
100
101
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
102
    // load weight idx
Guolin Ke's avatar
Guolin Ke committed
103
104
105
    if (config_.weight_column.size() > 0) {
      if (Common::StartsWith(config_.weight_column, name_prefix)) {
        std::string name = config_.weight_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
106
107
108
109
110
111
        if (name2idx.count(name) > 0) {
          weight_idx_ = name2idx[name];
          Log::Info("Using column %s as weight", name.c_str());
        } else {
          Log::Fatal("Could not find weight column %s in data file", name.c_str());
        }
Guolin Ke's avatar
Guolin Ke committed
112
      } else {
Guolin Ke's avatar
Guolin Ke committed
113
        if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
114
115
116
          Log::Fatal("weight_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
117
118
        }
        Log::Info("Using column number %d as weight", weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
119
      }
Guolin Ke's avatar
Guolin Ke committed
120
      ignore_features_.emplace(weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
121
    }
Guolin Ke's avatar
Guolin Ke committed
122
    // load group idx
Guolin Ke's avatar
Guolin Ke committed
123
124
125
    if (config_.group_column.size() > 0) {
      if (Common::StartsWith(config_.group_column, name_prefix)) {
        std::string name = config_.group_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
126
127
128
129
130
131
132
        if (name2idx.count(name) > 0) {
          group_idx_ = name2idx[name];
          Log::Info("Using column %s as group/query id", name.c_str());
        } else {
          Log::Fatal("Could not find group/query column %s in data file", name.c_str());
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
133
        if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
134
135
136
          Log::Fatal("group_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
137
138
139
140
        }
        Log::Info("Using column number %d as group/query id", group_idx_);
      }
      ignore_features_.emplace(group_idx_);
Guolin Ke's avatar
Guolin Ke committed
141
142
    }
  }
Guolin Ke's avatar
Guolin Ke committed
143
144
145
  if (config_.categorical_feature.size() > 0) {
    if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
      std::string names = config_.categorical_feature.substr(name_prefix.size());
146
147
148
149
150
      for (auto name : Common::Split(names.c_str(), ',')) {
        if (name2idx.count(name) > 0) {
          int tmp = name2idx[name];
          categorical_features_.emplace(tmp);
        } else {
Guolin Ke's avatar
Guolin Ke committed
151
          Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
152
153
154
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
155
      for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
156
157
        int tmp = 0;
        if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Guolin Ke's avatar
Guolin Ke committed
158
          Log::Fatal("categorical_feature is not a number,\n"
159
160
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
161
162
163
164
165
        }
        categorical_features_.emplace(tmp);
      }
    }
  }
Guolin Ke's avatar
Guolin Ke committed
166
167
}

168
/*!
 * \brief Load a training dataset from a text file, or from its binary file when
 *        CheckCanLoadFromBin returns a non-empty binary file name.
 *
 * Text loading either reads the whole file into memory (config_.two_round == false)
 * or makes two passes over the file (sample first, then extract features).
 *
 * \param filename Path of the data file
 * \param rank Rank of this machine, forwarded to the loading/sampling helpers
 * \param num_machines Number of machines taking part in training
 * \return Newly allocated Dataset; ownership is transferred to the caller
 */
Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_machines) {
  // don't support query id in data file when training in parallel
  if (num_machines > 1 && !config_.pre_partition) {
    // Index 0 is a valid group/query column (e.g. the first entry of name2idx),
    // so the test must be ">= 0" rather than "> 0".
    // assumes NO_SPECIFIC (the "not configured" value) is negative — TODO confirm
    if (group_idx_ >= 0) {
      Log::Fatal("Using a query id without pre-partitioning the data file is not supported for parallel training.\n"
                 "Please use an additional query file or pre-partition the data");
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto bin_filename = CheckCanLoadFromBin(filename);
  if (bin_filename.size() == 0) {
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data to memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // sample data
      auto sample_data = SampleTextDataFromMemory(text_data);
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      // sample data from file
      auto sample_data = SampleTextDataFromFile(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      // a non-empty index list means only a subset of the file is used locally
      if (used_data_indices.size() > 0) {
        dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
      } else {
        dataset->num_data_ = num_global_data;
      }
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      Log::Debug("Making second pass...");
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
  }
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  // need to check training data
  CheckDataset(dataset.get());
  return dataset.release();
}



230
/*!
 * \brief Load a dataset from file and align it with an existing dataset via
 *        Dataset::CreateValid (text input), or read it from its binary file.
 *
 * Loading is always done as a single machine (rank 0 of 1). Unlike LoadFromFile,
 * no CheckDataset call is made on the result.
 *
 * \param filename Path of the data file
 * \param train_data Dataset passed to CreateValid for text input
 * \return Newly allocated Dataset; ownership is transferred to the caller
 */
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
  std::unique_ptr<Dataset> valid_set(new Dataset());
  std::vector<data_size_t> used_data_indices;
  data_size_t num_global_data = 0;

  auto bin_filename = CheckCanLoadFromBin(filename);
  if (bin_filename.size() != 0) {
    // a usable binary file exists: load data from it
    valid_set.reset(LoadFromBinFile(filename, bin_filename.c_str(), 0, 1, &num_global_data, &used_data_indices));
  } else {
    std::unique_ptr<Parser> line_parser(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (!line_parser) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    valid_set->data_filename_ = filename;
    valid_set->label_idx_ = label_idx_;
    valid_set->metadata_.Init(filename);

    if (config_.two_round) {
      // two passes over the file: count the lines first, then extract features
      TextReader<data_size_t> text_reader(filename, config_.header);
      valid_set->num_data_ = static_cast<data_size_t>(text_reader.CountLine());
      num_global_data = valid_set->num_data_;
      // set up label storage
      valid_set->metadata_.Init(valid_set->num_data_, weight_idx_, group_idx_);
      valid_set->CreateValid(train_data);
      // pull feature values straight from the file
      ExtractFeaturesFromFile(filename, line_parser.get(), used_data_indices, valid_set.get());
    } else {
      // single pass: keep every line of the file in memory
      auto lines = LoadTextDataToMemory(filename, valid_set->metadata_, 0, 1, &num_global_data, &used_data_indices);
      valid_set->num_data_ = static_cast<data_size_t>(lines.size());
      // set up label storage
      valid_set->metadata_.Init(valid_set->num_data_, weight_idx_, group_idx_);
      valid_set->CreateValid(train_data);
      // pull feature values from the in-memory lines
      ExtractFeaturesFromMemory(&lines, line_parser.get(), valid_set.get());
      lines.clear();
    }
  }

  // validation data is not run through CheckDataset; only the meta data is checked
  valid_set->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  return valid_set.release();
}

274
275
276
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
                                        int rank, int num_machines, int* num_global_data,
                                        std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
277
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
278
  auto reader = VirtualFileReader::Make(bin_filename);
Guolin Ke's avatar
Guolin Ke committed
279
  dataset->data_filename_ = data_filename;
280
  if (!reader->Init()) {
Guolin Ke's avatar
Guolin Ke committed
281
282
283
284
285
    Log::Fatal("Could not read binary data from %s", bin_filename);
  }

  // buffer to read binary file
  size_t buffer_size = 16 * 1024 * 1024;
Guolin Ke's avatar
Guolin Ke committed
286
  auto buffer = std::vector<char>(buffer_size);
287

288
289
  // check token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
290
291
  size_t read_cnt = reader->Read(buffer.data(), sizeof(char) * size_of_token);
  if (read_cnt != sizeof(char) * size_of_token) {
292
293
294
    Log::Fatal("Binary file error: token has the wrong size");
  }
  if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
295
    Log::Fatal("Input file is not LightGBM binary file");
296
  }
Guolin Ke's avatar
Guolin Ke committed
297
298

  // read size of header
299
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
300

301
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
302
303
304
    Log::Fatal("Binary file error: header has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
305
  size_t size_of_head = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
306
307
308
309

  // re-allocmate space if not enough
  if (size_of_head > buffer_size) {
    buffer_size = size_of_head;
Guolin Ke's avatar
Guolin Ke committed
310
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
311
312
  }
  // read header
313
  read_cnt = reader->Read(buffer.data(), size_of_head);
Guolin Ke's avatar
Guolin Ke committed
314
315
316
317
318

  if (read_cnt != size_of_head) {
    Log::Fatal("Binary file error: header is incorrect");
  }
  // get header
Guolin Ke's avatar
Guolin Ke committed
319
  const char* mem_ptr = buffer.data();
Guolin Ke's avatar
Guolin Ke committed
320
321
322
323
324
325
  dataset->num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_data_);
  dataset->num_features_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_features_);
  dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
326
327
  dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->label_idx_);
328
329
330
331
332
333
334
335
336
337
  dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->max_bin_);
  dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->bin_construct_sample_cnt_);
  dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->min_data_in_bin_);
  dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += sizeof(dataset->use_missing_);
  dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += sizeof(dataset->zero_as_missing_);
Guolin Ke's avatar
Guolin Ke committed
338
339
  const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
  dataset->used_feature_map_.clear();
Guolin Ke's avatar
Guolin Ke committed
340
  for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
341
342
    dataset->used_feature_map_.push_back(tmp_feature_map[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
  mem_ptr += sizeof(int) * dataset->num_total_features_;
  // num_groups
  dataset->num_groups_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_groups_);
  // real_feature_idx_
  const int* tmp_ptr_real_feature_idx_ = reinterpret_cast<const int*>(mem_ptr);
  dataset->real_feature_idx_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->real_feature_idx_.push_back(tmp_ptr_real_feature_idx_[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // feature2group
  const int* tmp_ptr_feature2group = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2group_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2group_.push_back(tmp_ptr_feature2group[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // feature2subfeature
  const int* tmp_ptr_feature2subfeature = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2subfeature_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2subfeature_.push_back(tmp_ptr_feature2subfeature[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // group_bin_boundaries
  const uint64_t* tmp_ptr_group_bin_boundaries = reinterpret_cast<const uint64_t*>(mem_ptr);
  dataset->group_bin_boundaries_.clear();
  for (int i = 0; i < dataset->num_groups_ + 1; ++i) {
    dataset->group_bin_boundaries_.push_back(tmp_ptr_group_bin_boundaries[i]);
  }
  mem_ptr += sizeof(uint64_t) * (dataset->num_groups_ + 1);

  // group_feature_start_
  const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_start_.clear();
379
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
380
381
382
383
384
385
386
387
388
389
390
391
    dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
  }
  mem_ptr += sizeof(int) * (dataset->num_groups_);

  // group_feature_cnt_
  const int* tmp_ptr_group_feature_cnt = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_cnt_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_cnt_.push_back(tmp_ptr_group_feature_cnt[i]);
  }
  mem_ptr += sizeof(int) * (dataset->num_groups_);

Belinda Trotta's avatar
Belinda Trotta committed
392
  if (!config_.max_bin_by_feature.empty()) {
393
394
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
395
396
397
398
399
400
401
402
403
404
405
406
407
408
    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
  } else {
    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
    dataset->max_bin_by_feature_.clear();
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
    }
  }
  mem_ptr += sizeof(int32_t) * (dataset->num_total_features_);
  if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
    dataset->max_bin_by_feature_.clear();
  }

Guolin Ke's avatar
Guolin Ke committed
409
  // get feature names
Guolin Ke's avatar
Guolin Ke committed
410
  dataset->feature_names_.clear();
Guolin Ke's avatar
Guolin Ke committed
411
412
413
414
415
416
417
418
419
420
  // write feature names
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int str_len = *(reinterpret_cast<const int*>(mem_ptr));
    mem_ptr += sizeof(int);
    std::stringstream str_buf;
    for (int j = 0; j < str_len; ++j) {
      char tmp_char = *(reinterpret_cast<const char*>(mem_ptr));
      mem_ptr += sizeof(char);
      str_buf << tmp_char;
    }
Guolin Ke's avatar
Guolin Ke committed
421
    dataset->feature_names_.emplace_back(str_buf.str());
Guolin Ke's avatar
Guolin Ke committed
422
  }
423
424
425
426
427
428
429
430
431
432
433
434
435
  // get forced_bin_bounds_
  dataset->forced_bin_bounds_ = std::vector<std::vector<double>>(dataset->num_total_features_, std::vector<double>());
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int num_bounds = *(reinterpret_cast<const int*>(mem_ptr));
    mem_ptr += sizeof(int);
    dataset->forced_bin_bounds_[i] = std::vector<double>();
    const double* tmp_ptr_forced_bounds = reinterpret_cast<const double*>(mem_ptr);
    for (int j = 0; j < num_bounds; ++j) {
      double bound = tmp_ptr_forced_bounds[j];
      dataset->forced_bin_bounds_[i].push_back(bound);
    }
    mem_ptr += num_bounds * sizeof(double);
  }
Guolin Ke's avatar
Guolin Ke committed
436
437

  // read size of meta data
438
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
439

440
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
441
442
443
    Log::Fatal("Binary file error: meta data has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
444
  size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
445
446
447
448

  // re-allocate space if not enough
  if (size_of_metadata > buffer_size) {
    buffer_size = size_of_metadata;
Guolin Ke's avatar
Guolin Ke committed
449
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
450
451
  }
  //  read meta data
452
  read_cnt = reader->Read(buffer.data(), size_of_metadata);
Guolin Ke's avatar
Guolin Ke committed
453
454
455
456
457

  if (read_cnt != size_of_metadata) {
    Log::Fatal("Binary file error: meta data is incorrect");
  }
  // load meta data
Guolin Ke's avatar
Guolin Ke committed
458
  dataset->metadata_.LoadFromMemory(buffer.data());
Guolin Ke's avatar
Guolin Ke committed
459

460
461
  *num_global_data = dataset->num_data_;
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
462
  // sample local used data if need to partition
Guolin Ke's avatar
Guolin Ke committed
463
  if (num_machines > 1 && !config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
464
465
466
467
    const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
468
        if (random_.NextShort(0, num_machines) == rank) {
469
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
470
471
472
473
474
475
476
477
478
        }
      }
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = dataset->metadata_.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (qid >= num_queries) {
479
480
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
481
482
483
484
        }
        if (i >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
485
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
486
487
488
489
490
            is_query_used = true;
          }
          ++qid;
        }
        if (is_query_used) {
491
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
492
493
494
        }
      }
    }
495
    dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
Guolin Ke's avatar
Guolin Ke committed
496
  }
497
  dataset->metadata_.PartitionLabel(*used_data_indices);
Guolin Ke's avatar
Guolin Ke committed
498
  // read feature data
Guolin Ke's avatar
Guolin Ke committed
499
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
500
    // read feature size
501
502
    read_cnt = reader->Read(buffer.data(), sizeof(size_t));
    if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
503
504
      Log::Fatal("Binary file error: feature %d has the wrong size", i);
    }
Guolin Ke's avatar
Guolin Ke committed
505
    size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
506
507
508
    // re-allocate space if not enough
    if (size_of_feature > buffer_size) {
      buffer_size = size_of_feature;
Guolin Ke's avatar
Guolin Ke committed
509
      buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
510
511
    }

512
    read_cnt = reader->Read(buffer.data(), size_of_feature);
Guolin Ke's avatar
Guolin Ke committed
513
514
515
516

    if (read_cnt != size_of_feature) {
      Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
    }
Guolin Ke's avatar
Guolin Ke committed
517
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
518
519
      new FeatureGroup(buffer.data(),
                       *num_global_data,
520
                       *used_data_indices)));
Guolin Ke's avatar
Guolin Ke committed
521
  }
Guolin Ke's avatar
Guolin Ke committed
522
  dataset->feature_groups_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
523
  dataset->is_finish_load_ = true;
Guolin Ke's avatar
Guolin Ke committed
524
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
525
526
}

527

528
529
530
Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
                                               int** sample_indices, int num_col, const int* num_per_col,
                                               size_t total_sample_size, data_size_t num_data) {
531
532
533
534
535
  int num_total_features = num_col;
  if (Network::num_machines() > 1) {
    num_total_features = Network::GlobalSyncUpByMax(num_total_features);
  }
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_total_features);
536
537
  // fill feature_names_ if not header
  if (feature_names_.empty()) {
538
    for (int i = 0; i < num_col; ++i) {
539
540
541
542
543
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
Belinda Trotta's avatar
Belinda Trotta committed
544
  if (!config_.max_bin_by_feature.empty()) {
545
546
    CHECK_EQ(static_cast<size_t>(num_col), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
547
  }
548
549
550
551
552

  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
553
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
554
    static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
555
556
557
558
559
560
561
562
563
564
565
566
567
  if (Network::num_machines() == 1) {
    // if only one machine, find bin locally
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
568
569
570
571
        bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
        if (!feat_is_unconstrained) {
            Log::Fatal("The output cannot be monotone with respect to categorical features");
        }
572
573
      }
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
574
575
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
576
                                config_.max_bin, config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter,
577
578
                                bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
579
580
      } else {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
581
                                config_.max_bin_by_feature[i], config_.min_data_in_bin,
582
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
583
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
584
      }
585
586
587
588
589
590
591
592
593
594
595
596
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    // if have multi-machines, need to find bin distributed
    // different machines will find bin for different features
    int num_machines = Network::num_machines();
    int rank = Network::rank();
    // start and len will store the process feature indices for different machines
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
597
    int step = (num_total_features + num_machines - 1) / num_machines;
598
599
600
601
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
602
      len[i] = std::min(step, num_total_features - start[i]);
603
604
      start[i + 1] = start[i] + len[i];
    }
605
    len[num_machines - 1] = num_total_features - start[num_machines - 1];
606
607
608
609
610
611
612
613
614
615
616
617
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < len[rank]; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
618
619
620
      if (num_col <= start[rank] + i) {
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
621
      if (config_.max_bin_by_feature.empty()) {
622
623
        bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
                                total_sample_size, config_.max_bin, config_.min_data_in_bin,
624
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
625
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
626
      } else {
627
628
        bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
                                total_sample_size, config_.max_bin_by_feature[start[rank] + i],
629
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
630
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
631
      }
632
633
      OMP_LOOP_EX_END();
    }
Guolin Ke's avatar
Guolin Ke committed
634
    OMP_THROW_EX();
635
    comm_size_t self_buf_size = 0;
636
    for (int i = 0; i < len[rank]; ++i) {
637
638
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
639
      }
640
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
641
    }
642
643
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
644
645
646
647
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
648
649
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
650
651
652
      // free
      bin_mappers[i].reset(nullptr);
    }
653
654
655
656
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
657
    }
658
659
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
660
    // gather global feature bin mappers
661
662
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
663
    // restore features bins from buffer
664
    for (int i = 0; i < num_total_features; ++i) {
665
666
667
668
669
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      bin_mappers[i].reset(new BinMapper());
670
671
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
672
    }
Guolin Ke's avatar
Guolin Ke committed
673
  }
Guolin Ke's avatar
Guolin Ke committed
674
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
Guolin Ke's avatar
Guolin Ke committed
675
  dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
676
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
677
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
678
}
Guolin Ke's avatar
Guolin Ke committed
679
680
681
682
683
684


// ---- private functions ----

void DatasetLoader::CheckDataset(const Dataset* dataset) {
  if (dataset->num_data_ <= 0) {
Guolin Ke's avatar
Guolin Ke committed
685
    Log::Fatal("Data file %s is empty", dataset->data_filename_.c_str());
Guolin Ke's avatar
Guolin Ke committed
686
  }
687
688
  if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
    Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
689
               static_cast<int>(dataset->feature_names_.size()));
690
  }
Guolin Ke's avatar
Guolin Ke committed
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
  bool is_feature_order_by_group = true;
  int last_group = -1;
  int last_sub_feature = -1;
  // if features are ordered, not need to use hist_buf
  for (int i = 0; i < dataset->num_features_; ++i) {
    int group = dataset->feature2group_[i];
    int sub_feature = dataset->feature2subfeature_[i];
    if (group < last_group) {
      is_feature_order_by_group = false;
    } else if (group == last_group) {
      if (sub_feature <= last_sub_feature) {
        is_feature_order_by_group = false;
        break;
      }
    }
    last_group = group;
    last_sub_feature = sub_feature;
  }
  if (!is_feature_order_by_group) {
710
    Log::Fatal("Features in dataset should be ordered by group");
Guolin Ke's avatar
Guolin Ke committed
711
  }
Guolin Ke's avatar
Guolin Ke committed
712
713
714
}

std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
715
716
                                                             int rank, int num_machines, int* num_global_data,
                                                             std::vector<data_size_t>* used_data_indices) {
717
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
718
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
719
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
720
721
722
723
724
725
726
727
728
    // read all lines
    *num_global_data = text_reader.ReadAllLines();
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();

    if (query_boundaries == nullptr) {
      // if not contain query data, minimal sample unit is one record
      *num_global_data = text_reader.ReadAndFilterLines([this, rank, num_machines](data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
729
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
730
731
732
733
734
735
736
737
738
739
740
741
742
743
          return true;
        } else {
          return false;
        }
      }, used_data_indices);
    } else {
      // if contain query data, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.ReadAndFilterLines(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
744
745
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
746
747
748
749
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
750
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
751
752
753
754
755
756
757
758
759
760
761
762
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices);
    }
  }
  return std::move(text_reader.Lines());
}

std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
Guolin Ke's avatar
Guolin Ke committed
763
  int sample_cnt = config_.bin_construct_sample_cnt;
764
765
  if (static_cast<size_t>(sample_cnt) > data.size()) {
    sample_cnt = static_cast<int>(data.size());
766
  }
767
  auto sample_indices = random_.Sample(static_cast<int>(data.size()), sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
768
  std::vector<std::string> out(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
769
770
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    const size_t idx = sample_indices[i];
Guolin Ke's avatar
Guolin Ke committed
771
    out[i] = data[idx];
Guolin Ke's avatar
Guolin Ke committed
772
773
774
775
  }
  return out;
}

776
777
778
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata,
                                                               int rank, int num_machines, int* num_global_data,
                                                               std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
779
  const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
780
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
781
  std::vector<std::string> out_data;
Guolin Ke's avatar
Guolin Ke committed
782
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
783
    *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(&random_, sample_cnt, &out_data));
Guolin Ke's avatar
Guolin Ke committed
784
785
786
787
788
789
790
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      *num_global_data = text_reader.SampleAndFilterFromFile([this, rank, num_machines]
      (data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
791
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
792
793
794
795
          return true;
        } else {
          return false;
        }
Guolin Ke's avatar
Guolin Ke committed
796
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
797
798
799
800
801
802
803
804
805
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.SampleAndFilterFromFile(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
806
807
          Log::Fatal("Query id exceeds the range of the query file, "
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
808
809
810
811
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
812
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
813
814
815
816
817
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
Guolin Ke's avatar
Guolin Ke committed
818
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
819
820
821
822
823
    }
  }
  return out_data;
}

824
825
826
void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
                                                    const std::vector<std::string>& sample_data,
                                                    const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
827
  std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
828
  std::vector<std::vector<int>> sample_indices;
Guolin Ke's avatar
Guolin Ke committed
829
830
  std::vector<std::pair<int, double>> oneline_features;
  double label;
Guolin Ke's avatar
Guolin Ke committed
831
  for (int i = 0; i < static_cast<int>(sample_data.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
832
833
834
835
    oneline_features.clear();
    // parse features
    parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
    for (std::pair<int, double>& inner_data : oneline_features) {
836
      if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
837
838
        sample_values.resize(inner_data.first + 1);
        sample_indices.resize(inner_data.first + 1);
839
      }
Guolin Ke's avatar
Guolin Ke committed
840
      if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
841
842
        sample_values[inner_data.first].emplace_back(inner_data.second);
        sample_indices[inner_data.first].emplace_back(i);
Guolin Ke's avatar
Guolin Ke committed
843
844
845
846
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
847
  dataset->feature_groups_.clear();
848
849
850
851
852
  dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
  if (num_machines > 1) {
    dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_);
  }
  if (!feature_names_.empty()) {
853
    CHECK_EQ(dataset->num_total_features_, static_cast<int>(feature_names_.size()));
854
  }
Guolin Ke's avatar
Guolin Ke committed
855

Belinda Trotta's avatar
Belinda Trotta committed
856
  if (!config_.max_bin_by_feature.empty()) {
857
858
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
859
860
  }

861
862
  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
863
864
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path,
                                                                                    dataset->num_total_features_,
865
866
                                                                                    categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
867
868
869
870
871
872
  // check the range of label_idx, weight_idx and group_idx
  CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
  CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
  CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);

  // fill feature_names_ if not header
Guolin Ke's avatar
Guolin Ke committed
873
  if (feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
874
875
876
877
878
879
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
880
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
881
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
882
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
883
    static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
Guolin Ke's avatar
Guolin Ke committed
884
885
886
  // start find bins
  if (num_machines == 1) {
    // if only one machine, find bin locally
887
    OMP_INIT_EX();
888
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
889
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
890
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
891
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
892
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
893
894
        continue;
      }
895
896
897
898
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
Guolin Ke's avatar
Guolin Ke committed
899
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
900
901
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
902
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
903
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
904
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
905
906
      } else {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
907
                                sample_data.size(), config_.max_bin_by_feature[i],
908
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
909
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
910
      }
911
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
912
    }
913
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
914
915
  } else {
    // start and len will store the process feature indices for different machines
916
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
Guolin Ke's avatar
Guolin Ke committed
917
918
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
919
    int step = (dataset->num_total_features_ + num_machines - 1) / num_machines;
Guolin Ke's avatar
Guolin Ke committed
920
921
922
923
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
924
      len[i] = std::min(step, dataset->num_total_features_ - start[i]);
Guolin Ke's avatar
Guolin Ke committed
925
926
      start[i + 1] = start[i] + len[i];
    }
927
    len[num_machines - 1] = dataset->num_total_features_ - start[num_machines - 1];
928
    OMP_INIT_EX();
929
    #pragma omp parallel for schedule(guided)
930
    for (int i = 0; i < len[rank]; ++i) {
931
      OMP_LOOP_EX_BEGIN();
932
933
934
935
936
937
938
939
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
Nikita Titov's avatar
Nikita Titov committed
940
      if (static_cast<int>(sample_values.size()) <= start[rank] + i) {
941
942
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
943
      if (config_.max_bin_by_feature.empty()) {
944
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
945
                                static_cast<int>(sample_values[start[rank] + i].size()),
946
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
947
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
948
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
949
      } else {
950
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
951
                                static_cast<int>(sample_values[start[rank] + i].size()),
952
                                sample_data.size(), config_.max_bin_by_feature[i],
953
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type,
954
                                config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
955
      }
956
      OMP_LOOP_EX_END();
957
    }
958
    OMP_THROW_EX();
959
    comm_size_t self_buf_size = 0;
Guolin Ke's avatar
Guolin Ke committed
960
    for (int i = 0; i < len[rank]; ++i) {
961
962
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
Guolin Ke's avatar
Guolin Ke committed
963
      }
964
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
965
    }
966
967
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
968
    for (int i = 0; i < len[rank]; ++i) {
Guolin Ke's avatar
Guolin Ke committed
969
970
971
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
972
973
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
974
975
      // free
      bin_mappers[i].reset(nullptr);
Guolin Ke's avatar
Guolin Ke committed
976
    }
977
978
979
980
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
Guolin Ke's avatar
Guolin Ke committed
981
    }
982
983
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
Guolin Ke's avatar
Guolin Ke committed
984
    // gather global feature bin mappers
985
986
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
987
    // restore features bins from buffer
988
    for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
989
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
990
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
991
992
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
993
      bin_mappers[i].reset(new BinMapper());
994
995
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
Guolin Ke's avatar
Guolin Ke committed
996
997
    }
  }
998
  dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Guolin Ke's avatar
Guolin Ke committed
999
                     Common::Vector2Ptr<double>(&sample_values).data(),
1000
                     Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
Guolin Ke's avatar
Guolin Ke committed
1001
1002
1003
}

/*! \brief Extract local features from memory */
Guolin Ke's avatar
Guolin Ke committed
1004
void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
1005
1006
  std::vector<std::pair<int, double>> oneline_features;
  double tmp_label = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
1007
  auto& ref_text_data = *text_data;
Guolin Ke's avatar
Guolin Ke committed
1008
  if (predict_fun_ == nullptr) {
1009
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1010
    // if doesn't need to prediction with initial model
1011
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1012
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1013
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1014
1015
1016
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1017
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1018
      // set label
1019
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1020
      // free processed line:
Guolin Ke's avatar
Guolin Ke committed
1021
      ref_text_data[i].clear();
Guolin Ke's avatar
Guolin Ke committed
1022
1023
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
1024
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1025
1026
      // push data
      for (auto& inner_data : oneline_features) {
1027
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1028
1029
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1030
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1031
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1032
1033
1034
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1035
1036
        } else {
          if (inner_data.first == weight_idx_) {
1037
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1038
1039
1040
1041
1042
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1043
      dataset->FinishOneRow(tid, i, is_feature_added);
1044
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1045
    }
1046
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1047
  } else {
1048
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1049
    // if need to prediction with initial model
1050
    std::vector<double> init_score(dataset->num_data_ * num_class_);
1051
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1052
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1053
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1054
1055
1056
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1057
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1058
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1059
1060
      std::vector<double> oneline_init_score(num_class_);
      predict_fun_(oneline_features, oneline_init_score.data());
1061
      for (int k = 0; k < num_class_; ++k) {
1062
        init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1063
1064
      }
      // set label
1065
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1066
1067
1068
1069
1070
      // free processed line:
      text_data[i].clear();
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
Guolin Ke's avatar
Guolin Ke committed
1071
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1072
      for (auto& inner_data : oneline_features) {
1073
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1074
1075
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1076
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1077
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1078
1079
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
1080
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1081
1082
        } else {
          if (inner_data.first == weight_idx_) {
1083
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1084
1085
1086
1087
1088
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1089
      dataset->FinishOneRow(tid, i, is_feature_added);
1090
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1091
    }
1092
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1093
    // metadata_ will manage space of init_score
1094
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1095
  }
Guolin Ke's avatar
Guolin Ke committed
1096
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1097
  // text data can be free after loaded feature values
Guolin Ke's avatar
Guolin Ke committed
1098
  text_data->clear();
Guolin Ke's avatar
Guolin Ke committed
1099
1100
1101
}

/*! \brief Extract local features from file */
1102
1103
void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser,
                                            const std::vector<data_size_t>& used_data_indices, Dataset* dataset) {
1104
  std::vector<double> init_score;
Guolin Ke's avatar
Guolin Ke committed
1105
  if (predict_fun_ != nullptr) {
1106
    init_score = std::vector<double>(dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1107
1108
1109
1110
1111
1112
  }
  std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
    [this, &init_score, &parser, &dataset]
  (data_size_t start_idx, const std::vector<std::string>& lines) {
    std::vector<std::pair<int, double>> oneline_features;
    double tmp_label = 0.0f;
1113
    OMP_INIT_EX();
1114
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1115
    for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
1116
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1117
1118
1119
1120
1121
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1122
      if (!init_score.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1123
1124
        std::vector<double> oneline_init_score(num_class_);
        predict_fun_(oneline_features, oneline_init_score.data());
1125
        for (int k = 0; k < num_class_; ++k) {
1126
          init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1127
1128
1129
        }
      }
      // set label
1130
      dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1131
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1132
1133
      // push data
      for (auto& inner_data : oneline_features) {
1134
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1135
1136
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1137
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1138
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1139
1140
1141
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1142
1143
        } else {
          if (inner_data.first == weight_idx_) {
1144
            dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1145
1146
1147
1148
1149
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1150
      dataset->FinishOneRow(tid, i, is_feature_added);
1151
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1152
    }
1153
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1154
  };
1155
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
1156
  if (!used_data_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1157
1158
1159
1160
1161
1162
1163
1164
    // only need part of data
    text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
  } else {
    // need full data
    text_reader.ReadAllAndProcessParallel(process_fun);
  }

  // metadata_ will manage space of init_score
Guolin Ke's avatar
Guolin Ke committed
1165
  if (!init_score.empty()) {
1166
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1167
  }
Guolin Ke's avatar
Guolin Ke committed
1168
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1169
1170
1171
}

/*! \brief Check can load from binary file */
1172
std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
1173
1174
1175
  std::string bin_filename(filename);
  bin_filename.append(".bin");

1176
  auto reader = VirtualFileReader::Make(bin_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
1177

1178
  if (!reader->Init()) {
1179
    bin_filename = std::string(filename);
1180
1181
    reader = VirtualFileReader::Make(bin_filename.c_str());
    if (!reader->Init()) {
1182
      Log::Fatal("Cannot open data file %s", bin_filename.c_str());
1183
    }
1184
  }
1185
1186
1187
1188
1189

  size_t buffer_size = 256;
  auto buffer = std::vector<char>(buffer_size);
  // read size of token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
1190
  size_t read_cnt = reader->Read(buffer.data(), size_of_token);
1191
1192
  if (read_cnt == size_of_token
      && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
1193
    return bin_filename;
Guolin Ke's avatar
Guolin Ke committed
1194
  } else {
1195
    return std::string();
Guolin Ke's avatar
Guolin Ke committed
1196
1197
1198
  }
}

1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236


std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
                                                              const std::unordered_set<int>& categorical_features) {
  std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
  if (forced_bins_path != "") {
    std::ifstream forced_bins_stream(forced_bins_path.c_str());
    if (forced_bins_stream.fail()) {
      Log::Warning("Could not open %s. Will ignore.", forced_bins_path.c_str());
    } else {
      std::stringstream buffer;
      buffer << forced_bins_stream.rdbuf();
      std::string err;
      Json forced_bins_json = Json::parse(buffer.str(), err);
      CHECK(forced_bins_json.is_array());
      std::vector<Json> forced_bins_arr = forced_bins_json.array_items();
      for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
        int feature_num = forced_bins_arr[i]["feature"].int_value();
        CHECK(feature_num < num_total_features);
        if (categorical_features.count(feature_num)) {
          Log::Warning("Feature %d is categorical. Will ignore forced bins for this  feature.", feature_num);
        } else {
          std::vector<Json> bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items();
          for (size_t j = 0; j < bounds_arr.size(); ++j) {
            forced_bins[feature_num].push_back(bounds_arr[j].number_value());
          }
        }
      }
      // remove duplicates
      for (int i = 0; i < num_total_features; ++i) {
        auto new_end = std::unique(forced_bins[i].begin(), forced_bins[i].end());
        forced_bins[i].erase(new_end, forced_bins[i].end());
      }
    }
  }
  return forced_bins;
}

1237
}  // namespace LightGBM