/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
#include <LightGBM/dataset_loader.h>

Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/network.h>
8
#include <LightGBM/utils/array_args.h>
9
#include <LightGBM/utils/json11.h>
10
11
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
12

13
14
#include <fstream>

Guolin Ke's avatar
Guolin Ke committed
15
16
namespace LightGBM {

17
18
using json11::Json;

Guolin Ke's avatar
Guolin Ke committed
19
20
/*!
 * \brief Build a loader bound to a configuration and (optionally) a data file.
 * \param io_config   Configuration copied/bound into config_; its data_random_seed seeds random_.
 * \param predict_fun Prediction callback stored in predict_fun_.
 * \param num_class   Number of classes stored in num_class_.
 * \param filename    Data file whose header and column settings are resolved by SetHeader;
 *                    may be nullptr, in which case only categorical_feature is processed there.
 */
DatasetLoader::DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename)
  :config_(io_config), random_(config_.data_random_seed), predict_fun_(predict_fun), num_class_(num_class) {
  // Defaults used when the configuration does not name the columns explicitly:
  // the label is column 0, and no weight / group column is taken from the file.
  label_idx_ = 0;
  weight_idx_ = NO_SPECIFIC;
  group_idx_ = NO_SPECIFIC;
  // Resolve label/weight/group/ignore/categorical column settings (may override the defaults).
  SetHeader(filename);
}

/*! \brief Nothing to release by hand; members clean up after themselves. */
DatasetLoader::~DatasetLoader() = default;

Guolin Ke's avatar
Guolin Ke committed
30
void DatasetLoader::SetHeader(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
31
  std::unordered_map<std::string, int> name2idx;
Guolin Ke's avatar
Guolin Ke committed
32
33
  std::string name_prefix("name:");
  if (filename != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
34
    TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
35

Guolin Ke's avatar
Guolin Ke committed
36
    // get column names
Guolin Ke's avatar
Guolin Ke committed
37
    if (config_.header) {
Guolin Ke's avatar
Guolin Ke committed
38
      std::string first_line = text_reader.first_line();
39
      feature_names_ = Common::Split(first_line.c_str(), "\t,");
Guolin Ke's avatar
Guolin Ke committed
40
41
    }

Guolin Ke's avatar
Guolin Ke committed
42
    // load label idx first
Guolin Ke's avatar
Guolin Ke committed
43
44
45
    if (config_.label_column.size() > 0) {
      if (Common::StartsWith(config_.label_column, name_prefix)) {
        std::string name = config_.label_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
46
47
48
49
50
51
52
53
54
55
        label_idx_ = -1;
        for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
          if (name == feature_names_[i]) {
            label_idx_ = i;
            break;
          }
        }
        if (label_idx_ >= 0) {
          Log::Info("Using column %s as label", name.c_str());
        } else {
56
57
          Log::Fatal("Could not find label column %s in data file \n"
                     "or data file doesn't contain header", name.c_str());
Guolin Ke's avatar
Guolin Ke committed
58
        }
Guolin Ke's avatar
Guolin Ke committed
59
      } else {
Guolin Ke's avatar
Guolin Ke committed
60
        if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
61
62
63
          Log::Fatal("label_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
64
65
        }
        Log::Info("Using column number %d as label", label_idx_);
Guolin Ke's avatar
Guolin Ke committed
66
67
      }
    }
Guolin Ke's avatar
Guolin Ke committed
68

Guolin Ke's avatar
Guolin Ke committed
69
    if (!feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
70
71
72
73
      // erase label column name
      feature_names_.erase(feature_names_.begin() + label_idx_);
      for (size_t i = 0; i < feature_names_.size(); ++i) {
        name2idx[feature_names_[i]] = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
74
      }
Guolin Ke's avatar
Guolin Ke committed
75
76
77
    }

    // load ignore columns
Guolin Ke's avatar
Guolin Ke committed
78
79
80
    if (config_.ignore_column.size() > 0) {
      if (Common::StartsWith(config_.ignore_column, name_prefix)) {
        std::string names = config_.ignore_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
81
82
83
84
85
86
87
88
89
        for (auto name : Common::Split(names.c_str(), ',')) {
          if (name2idx.count(name) > 0) {
            int tmp = name2idx[name];
            ignore_features_.emplace(tmp);
          } else {
            Log::Fatal("Could not find ignore column %s in data file", name.c_str());
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
90
        for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
Guolin Ke's avatar
Guolin Ke committed
91
92
          int tmp = 0;
          if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
93
94
95
            Log::Fatal("ignore_column is not a number,\n"
                       "if you want to use a column name,\n"
                       "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
96
97
          }
          ignore_features_.emplace(tmp);
Guolin Ke's avatar
Guolin Ke committed
98
99
100
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
101
    // load weight idx
Guolin Ke's avatar
Guolin Ke committed
102
103
104
    if (config_.weight_column.size() > 0) {
      if (Common::StartsWith(config_.weight_column, name_prefix)) {
        std::string name = config_.weight_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
105
106
107
108
109
110
        if (name2idx.count(name) > 0) {
          weight_idx_ = name2idx[name];
          Log::Info("Using column %s as weight", name.c_str());
        } else {
          Log::Fatal("Could not find weight column %s in data file", name.c_str());
        }
Guolin Ke's avatar
Guolin Ke committed
111
      } else {
Guolin Ke's avatar
Guolin Ke committed
112
        if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
113
114
115
          Log::Fatal("weight_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
116
117
        }
        Log::Info("Using column number %d as weight", weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
118
      }
Guolin Ke's avatar
Guolin Ke committed
119
      ignore_features_.emplace(weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
120
    }
Guolin Ke's avatar
Guolin Ke committed
121
    // load group idx
Guolin Ke's avatar
Guolin Ke committed
122
123
124
    if (config_.group_column.size() > 0) {
      if (Common::StartsWith(config_.group_column, name_prefix)) {
        std::string name = config_.group_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
125
126
127
128
129
130
131
        if (name2idx.count(name) > 0) {
          group_idx_ = name2idx[name];
          Log::Info("Using column %s as group/query id", name.c_str());
        } else {
          Log::Fatal("Could not find group/query column %s in data file", name.c_str());
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
132
        if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
133
134
135
          Log::Fatal("group_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
136
137
138
139
        }
        Log::Info("Using column number %d as group/query id", group_idx_);
      }
      ignore_features_.emplace(group_idx_);
Guolin Ke's avatar
Guolin Ke committed
140
141
    }
  }
Guolin Ke's avatar
Guolin Ke committed
142
143
144
  if (config_.categorical_feature.size() > 0) {
    if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
      std::string names = config_.categorical_feature.substr(name_prefix.size());
145
146
147
148
149
      for (auto name : Common::Split(names.c_str(), ',')) {
        if (name2idx.count(name) > 0) {
          int tmp = name2idx[name];
          categorical_features_.emplace(tmp);
        } else {
Guolin Ke's avatar
Guolin Ke committed
150
          Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
151
152
153
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
154
      for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
155
156
        int tmp = 0;
        if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Guolin Ke's avatar
Guolin Ke committed
157
          Log::Fatal("categorical_feature is not a number,\n"
158
159
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
160
161
162
163
164
        }
        categorical_features_.emplace(tmp);
      }
    }
  }
Guolin Ke's avatar
Guolin Ke committed
165
166
}

167
168
169
170
171
172
173
174
175
176
/*!
 * \brief Warn when the sample used for bin construction looks too small.
 *
 * The warning is emitted only when the sample is both less than 20% of the data
 * and fewer than 100000 rows.
 *
 * \param sample_cnt Number of sampled rows.
 * \param num_data   Total number of rows the sample was drawn from.
 */
void CheckSampleSize(size_t sample_cnt, size_t num_data) {
  const double sampled_fraction = static_cast<double>(sample_cnt) / num_data;
  // Written as !(x < limit) rather than x >= limit so a NaN fraction (0 / 0)
  // keeps taking the quiet path.
  const bool fraction_is_enough = !(sampled_fraction < 0.2f);
  if (fraction_is_enough || sample_cnt >= 100000) {
    return;
  }
  Log::Warning(
      "Using too small ``bin_construct_sample_cnt`` may encounter unexpected "
      "errors and poor accuracy.");
}

177
/*!
 * \brief Load a training dataset from a text file or from its binary counterpart.
 *
 * If CheckCanLoadFromBin returns a non-empty path the dataset is read with
 * LoadFromBinFile. Otherwise the text file is parsed, either fully in memory or,
 * when config_.two_round is set, in two passes over the file (sample, then extract).
 *
 * \param filename     Path of the data file.
 * \param rank         Rank of this machine, forwarded to the loading helpers.
 * \param num_machines Number of machines taking part in training.
 * \return Raw pointer released from the internal unique_ptr; the caller takes ownership.
 */
Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_machines) {
  // don't support query id in data file when training in parallel
  if (num_machines > 1 && !config_.pre_partition) {
    // Index 0 is a legitimate group column (SetHeader numbers the non-label columns
    // from 0), so the test must include it.
    // NOTE(review): assumes NO_SPECIFIC is negative - confirm against its definition.
    if (group_idx_ >= 0) {
      Log::Fatal("Using a query id without pre-partitioning the data file is not supported for parallel training.\n"
                 "Please use an additional query file or pre-partition the data");
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto bin_filename = CheckCanLoadFromBin(filename);
  bool is_load_from_binary = false;
  if (bin_filename.size() == 0) {
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data to memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // sample data
      auto sample_data = SampleTextDataFromMemory(text_data);
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      // sample data from file
      auto sample_data = SampleTextDataFromFile(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      // when a subset of rows was selected for this machine, that subset is the local data
      if (used_data_indices.size() > 0) {
        dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
      } else {
        dataset->num_data_ = num_global_data;
      }
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      Log::Debug("Making second pass...");
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    is_load_from_binary = true;
    Log::Info("Load from binary file %s", bin_filename.c_str());
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
  }
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  // need to check training data
  CheckDataset(dataset.get(), is_load_from_binary);

  return dataset.release();
}



247
/*!
 * \brief Load a dataset that reuses the structure of an existing training dataset.
 *
 * Instead of constructing bin mappers from a sample, the new dataset is set up with
 * Dataset::CreateValid(train_data). The file is always read as a single machine
 * (rank 0 of 1). A binary counterpart, if CheckCanLoadFromBin finds one, is loaded
 * with LoadFromBinFile instead.
 *
 * \param filename   Path of the data file.
 * \param train_data Dataset passed to CreateValid for the text-file paths.
 * \return Raw pointer released from the internal unique_ptr; the caller takes ownership.
 */
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  auto bin_filename = CheckCanLoadFromBin(filename);
  if (bin_filename.size() == 0) {
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data in memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      dataset->CreateValid(train_data);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      TextReader<data_size_t> text_reader(filename, config_.header);
      // Get number of lines of data file
      dataset->num_data_ = static_cast<data_size_t>(text_reader.CountLine());
      num_global_data = dataset->num_data_;
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      dataset->CreateValid(train_data);
      // extract features; used_data_indices is still empty here
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), 0, 1, &num_global_data, &used_data_indices));
  }
  // not need to check validation data
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  return dataset.release();
}

291
292
293
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
                                        int rank, int num_machines, int* num_global_data,
                                        std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
294
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
295
  auto reader = VirtualFileReader::Make(bin_filename);
Guolin Ke's avatar
Guolin Ke committed
296
  dataset->data_filename_ = data_filename;
297
  if (!reader->Init()) {
Guolin Ke's avatar
Guolin Ke committed
298
299
300
301
302
    Log::Fatal("Could not read binary data from %s", bin_filename);
  }

  // buffer to read binary file
  size_t buffer_size = 16 * 1024 * 1024;
Guolin Ke's avatar
Guolin Ke committed
303
  auto buffer = std::vector<char>(buffer_size);
304

305
306
  // check token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
307
308
309
310
  size_t read_cnt = reader->Read(
      buffer.data(),
      VirtualFileWriter::AlignedSize(sizeof(char) * size_of_token));
  if (read_cnt < sizeof(char) * size_of_token) {
311
312
313
    Log::Fatal("Binary file error: token has the wrong size");
  }
  if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
314
    Log::Fatal("Input file is not LightGBM binary file");
315
  }
Guolin Ke's avatar
Guolin Ke committed
316
317

  // read size of header
318
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
319

320
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
321
322
323
    Log::Fatal("Binary file error: header has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
324
  size_t size_of_head = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
325
326
327
328

  // re-allocmate space if not enough
  if (size_of_head > buffer_size) {
    buffer_size = size_of_head;
Guolin Ke's avatar
Guolin Ke committed
329
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
330
331
  }
  // read header
332
  read_cnt = reader->Read(buffer.data(), size_of_head);
Guolin Ke's avatar
Guolin Ke committed
333
334
335
336
337

  if (read_cnt != size_of_head) {
    Log::Fatal("Binary file error: header is incorrect");
  }
  // get header
Guolin Ke's avatar
Guolin Ke committed
338
  const char* mem_ptr = buffer.data();
Guolin Ke's avatar
Guolin Ke committed
339
  dataset->num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
340
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_data_));
Guolin Ke's avatar
Guolin Ke committed
341
  dataset->num_features_ = *(reinterpret_cast<const int*>(mem_ptr));
342
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_features_));
Guolin Ke's avatar
Guolin Ke committed
343
  dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
344
345
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(dataset->num_total_features_));
Guolin Ke's avatar
Guolin Ke committed
346
  dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
347
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->label_idx_));
348
  dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
349
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->max_bin_));
350
  dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
351
352
  mem_ptr += VirtualFileWriter::AlignedSize(
      sizeof(dataset->bin_construct_sample_cnt_));
353
  dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
354
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->min_data_in_bin_));
355
  dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
356
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->use_missing_));
357
  dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
358
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->zero_as_missing_));
Guolin Ke's avatar
Guolin Ke committed
359
360
  const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
  dataset->used_feature_map_.clear();
Guolin Ke's avatar
Guolin Ke committed
361
  for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
362
363
    dataset->used_feature_map_.push_back(tmp_feature_map[i]);
  }
364
365
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int) *
                                            dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
366
367
  // num_groups
  dataset->num_groups_ = *(reinterpret_cast<const int*>(mem_ptr));
368
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
369
370
371
372
373
374
  // real_feature_idx_
  const int* tmp_ptr_real_feature_idx_ = reinterpret_cast<const int*>(mem_ptr);
  dataset->real_feature_idx_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->real_feature_idx_.push_back(tmp_ptr_real_feature_idx_[i]);
  }
375
376
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
377
378
379
380
381
382
  // feature2group
  const int* tmp_ptr_feature2group = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2group_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2group_.push_back(tmp_ptr_feature2group[i]);
  }
383
384
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
385
386
387
388
389
390
  // feature2subfeature
  const int* tmp_ptr_feature2subfeature = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2subfeature_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2subfeature_.push_back(tmp_ptr_feature2subfeature[i]);
  }
391
392
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
393
394
395
396
397
398
399
400
401
402
403
  // group_bin_boundaries
  const uint64_t* tmp_ptr_group_bin_boundaries = reinterpret_cast<const uint64_t*>(mem_ptr);
  dataset->group_bin_boundaries_.clear();
  for (int i = 0; i < dataset->num_groups_ + 1; ++i) {
    dataset->group_bin_boundaries_.push_back(tmp_ptr_group_bin_boundaries[i]);
  }
  mem_ptr += sizeof(uint64_t) * (dataset->num_groups_ + 1);

  // group_feature_start_
  const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_start_.clear();
404
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
405
406
    dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
  }
407
408
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
409
410
411
412
413
414
415

  // group_feature_cnt_
  const int* tmp_ptr_group_feature_cnt = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_cnt_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_cnt_.push_back(tmp_ptr_group_feature_cnt[i]);
  }
416
417
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
418

Belinda Trotta's avatar
Belinda Trotta committed
419
  if (!config_.max_bin_by_feature.empty()) {
420
421
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
422
423
424
425
426
427
428
429
430
    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
  } else {
    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
    dataset->max_bin_by_feature_.clear();
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
    }
  }
431
432
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int32_t) *
                                            (dataset->num_total_features_));
Belinda Trotta's avatar
Belinda Trotta committed
433
434
435
436
  if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
    dataset->max_bin_by_feature_.clear();
  }

Guolin Ke's avatar
Guolin Ke committed
437
  // get feature names
Guolin Ke's avatar
Guolin Ke committed
438
  dataset->feature_names_.clear();
Guolin Ke's avatar
Guolin Ke committed
439
440
441
  // write feature names
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int str_len = *(reinterpret_cast<const int*>(mem_ptr));
442
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
Guolin Ke's avatar
Guolin Ke committed
443
    std::stringstream str_buf;
444
    auto tmp_arr = reinterpret_cast<const char*>(mem_ptr);
Guolin Ke's avatar
Guolin Ke committed
445
    for (int j = 0; j < str_len; ++j) {
446
      char tmp_char = tmp_arr[j];
Guolin Ke's avatar
Guolin Ke committed
447
448
      str_buf << tmp_char;
    }
449
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(char) * str_len);
Guolin Ke's avatar
Guolin Ke committed
450
    dataset->feature_names_.emplace_back(str_buf.str());
Guolin Ke's avatar
Guolin Ke committed
451
  }
452
453
454
455
  // get forced_bin_bounds_
  dataset->forced_bin_bounds_ = std::vector<std::vector<double>>(dataset->num_total_features_, std::vector<double>());
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int num_bounds = *(reinterpret_cast<const int*>(mem_ptr));
456
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
457
    dataset->forced_bin_bounds_[i] = std::vector<double>();
458
459
    const double* tmp_ptr_forced_bounds =
        reinterpret_cast<const double*>(mem_ptr);
460
461
462
463
464
465
    for (int j = 0; j < num_bounds; ++j) {
      double bound = tmp_ptr_forced_bounds[j];
      dataset->forced_bin_bounds_[i].push_back(bound);
    }
    mem_ptr += num_bounds * sizeof(double);
  }
Guolin Ke's avatar
Guolin Ke committed
466
467

  // read size of meta data
468
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
469

470
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
471
472
473
    Log::Fatal("Binary file error: meta data has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
474
  size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
475
476
477
478

  // re-allocate space if not enough
  if (size_of_metadata > buffer_size) {
    buffer_size = size_of_metadata;
Guolin Ke's avatar
Guolin Ke committed
479
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
480
481
  }
  //  read meta data
482
  read_cnt = reader->Read(buffer.data(), size_of_metadata);
Guolin Ke's avatar
Guolin Ke committed
483
484
485
486
487

  if (read_cnt != size_of_metadata) {
    Log::Fatal("Binary file error: meta data is incorrect");
  }
  // load meta data
Guolin Ke's avatar
Guolin Ke committed
488
  dataset->metadata_.LoadFromMemory(buffer.data());
Guolin Ke's avatar
Guolin Ke committed
489

490
491
  *num_global_data = dataset->num_data_;
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
492
  // sample local used data if need to partition
Guolin Ke's avatar
Guolin Ke committed
493
  if (num_machines > 1 && !config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
494
495
496
497
    const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
498
        if (random_.NextShort(0, num_machines) == rank) {
499
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
500
501
502
503
504
505
506
507
508
        }
      }
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = dataset->metadata_.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (qid >= num_queries) {
509
510
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
511
512
513
514
        }
        if (i >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
515
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
516
517
518
519
520
            is_query_used = true;
          }
          ++qid;
        }
        if (is_query_used) {
521
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
522
523
524
        }
      }
    }
525
    dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
Guolin Ke's avatar
Guolin Ke committed
526
  }
527
  dataset->metadata_.PartitionLabel(*used_data_indices);
Guolin Ke's avatar
Guolin Ke committed
528
  // read feature data
Guolin Ke's avatar
Guolin Ke committed
529
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
530
    // read feature size
531
532
    read_cnt = reader->Read(buffer.data(), sizeof(size_t));
    if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
533
534
      Log::Fatal("Binary file error: feature %d has the wrong size", i);
    }
Guolin Ke's avatar
Guolin Ke committed
535
    size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
536
537
538
    // re-allocate space if not enough
    if (size_of_feature > buffer_size) {
      buffer_size = size_of_feature;
Guolin Ke's avatar
Guolin Ke committed
539
      buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
540
541
    }

542
    read_cnt = reader->Read(buffer.data(), size_of_feature);
Guolin Ke's avatar
Guolin Ke committed
543
544
545
546

    if (read_cnt != size_of_feature) {
      Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
    }
Guolin Ke's avatar
Guolin Ke committed
547
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
548
549
      new FeatureGroup(buffer.data(),
                       *num_global_data,
550
                       *used_data_indices, i)));
Guolin Ke's avatar
Guolin Ke committed
551
  }
Guolin Ke's avatar
Guolin Ke committed
552
  dataset->feature_groups_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
553
  dataset->is_finish_load_ = true;
Guolin Ke's avatar
Guolin Ke committed
554
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
555
556
}

557

558
559
560
/*!
 * \brief Build a Dataset from column-wise sampled values.
 *
 * Finds one BinMapper per column from the sampled values, then hands the
 * mappers to Dataset::Construct. With more than one machine the columns are
 * split into contiguous ranges, every machine bins its own range, and the
 * serialized mappers are exchanged with an allgather.
 *
 * \param sample_values sample_values[c] points to the sampled values of column c
 * \param sample_indices sample_indices[c] points to the row indices of those values
 * \param num_col number of columns available locally
 * \param num_per_col num_per_col[c] is the number of sampled values of column c
 * \param total_sample_size number of sampled rows
 * \param num_data number of rows of the dataset being created
 * \return newly allocated Dataset; ownership is released to the caller
 */
Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
                                                int** sample_indices, int num_col, const int* num_per_col,
                                                size_t total_sample_size, data_size_t num_data) {
  CheckSampleSize(total_sample_size, static_cast<size_t>(num_data));
  int num_total_features = num_col;
  if (Network::num_machines() > 1) {
    num_total_features = Network::GlobalSyncUpByMax(num_total_features);
  }
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_total_features);
  // fill feature_names_ if not header
  if (feature_names_.empty()) {
    for (int i = 0; i < num_col; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
  if (!config_.max_bin_by_feature.empty()) {
    CHECK_EQ(static_cast<size_t>(num_col), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
  }

  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);

  // min_data_in_leaf scaled from the full data down to the sample size
  const data_size_t filter_cnt = static_cast<data_size_t>(
    static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
  if (Network::num_machines() == 1) {
    // if only one machine, find bin locally
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
        bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
        if (!feat_is_unconstrained) {
            Log::Fatal("The output cannot be monotone with respect to categorical features");
        }
      }
      bin_mappers[i].reset(new BinMapper());
      // a per-feature limit, when configured, replaces the global max_bin
      const int max_bin = config_.max_bin_by_feature.empty() ? config_.max_bin : config_.max_bin_by_feature[i];
      bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                              max_bin, config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter,
                              bin_type, config_.use_missing, config_.zero_as_missing,
                              forced_bin_bounds[i]);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    // if have multi-machines, need to find bin distributed
    // different machines will find bin for different features
    int num_machines = Network::num_machines();
    int rank = Network::rank();
    // start and len will store the process feature indices for different machines
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
    int step = (num_total_features + num_machines - 1) / num_machines;
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
      len[i] = std::min(step, num_total_features - start[i]);
      start[i + 1] = start[i] + len[i];
    }
    len[num_machines - 1] = num_total_features - start[num_machines - 1];
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < len[rank]; ++i) {
      OMP_LOOP_EX_BEGIN();
      // i is only the slot in the local scratch area of bin_mappers;
      // every per-column input must be addressed with the global column index
      const int feature_idx = start[rank] + i;
      if (ignore_features_.count(feature_idx) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(feature_idx)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
      // column is not present locally: keep the default-constructed mapper
      if (num_col <= feature_idx) {
        continue;
      }
      const int max_bin = config_.max_bin_by_feature.empty() ? config_.max_bin : config_.max_bin_by_feature[feature_idx];
      // forced_bin_bounds uses the same per-column indexing as the single-machine path
      bin_mappers[i]->FindBin(sample_values[feature_idx], num_per_col[feature_idx],
                              total_sample_size, max_bin, config_.min_data_in_bin,
                              filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
                              config_.zero_as_missing, forced_bin_bounds[feature_idx]);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
    // size of the serialized local mappers
    comm_size_t self_buf_size = 0;
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
    }
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
      // free
      bin_mappers[i].reset(nullptr);
    }
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
    }
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
    // gather global feature bin mappers
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
    // restore features bins from buffer
    for (int i = 0; i < num_total_features; ++i) {
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      bin_mappers[i].reset(new BinMapper());
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
  dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
  dataset->set_feature_names(feature_names_);
  return dataset.release();
}
Guolin Ke's avatar
Guolin Ke committed
710
711
712
713


// ---- private functions ----

714
void DatasetLoader::CheckDataset(const Dataset* dataset, bool is_load_from_binary) {
Guolin Ke's avatar
Guolin Ke committed
715
  if (dataset->num_data_ <= 0) {
Guolin Ke's avatar
Guolin Ke committed
716
    Log::Fatal("Data file %s is empty", dataset->data_filename_.c_str());
Guolin Ke's avatar
Guolin Ke committed
717
  }
718
719
  if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
    Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
720
               static_cast<int>(dataset->feature_names_.size()));
721
  }
Guolin Ke's avatar
Guolin Ke committed
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
  bool is_feature_order_by_group = true;
  int last_group = -1;
  int last_sub_feature = -1;
  // if features are ordered, not need to use hist_buf
  for (int i = 0; i < dataset->num_features_; ++i) {
    int group = dataset->feature2group_[i];
    int sub_feature = dataset->feature2subfeature_[i];
    if (group < last_group) {
      is_feature_order_by_group = false;
    } else if (group == last_group) {
      if (sub_feature <= last_sub_feature) {
        is_feature_order_by_group = false;
        break;
      }
    }
    last_group = group;
    last_sub_feature = sub_feature;
  }
  if (!is_feature_order_by_group) {
741
    Log::Fatal("Features in dataset should be ordered by group");
Guolin Ke's avatar
Guolin Ke committed
742
  }
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768

  if (is_load_from_binary) {
    if (dataset->max_bin_ != config_.max_bin) {
      Log::Fatal("Dataset max_bin %d != config %d", dataset->max_bin_, config_.max_bin);
    }
    if (dataset->min_data_in_bin_ != config_.min_data_in_bin) {
      Log::Fatal("Dataset min_data_in_bin %d != config %d", dataset->min_data_in_bin_, config_.min_data_in_bin);
    }
    if (dataset->use_missing_ != config_.use_missing) {
      Log::Fatal("Dataset use_missing %d != config %d", dataset->use_missing_, config_.use_missing);
    }
    if (dataset->zero_as_missing_ != config_.zero_as_missing) {
      Log::Fatal("Dataset zero_as_missing %d != config %d", dataset->zero_as_missing_, config_.zero_as_missing);
    }
    if (dataset->bin_construct_sample_cnt_ != config_.bin_construct_sample_cnt) {
      Log::Fatal("Dataset bin_construct_sample_cnt %d != config %d", dataset->bin_construct_sample_cnt_, config_.bin_construct_sample_cnt);
    }
    if ((dataset->max_bin_by_feature_.size() != config_.max_bin_by_feature.size()) ||
        !std::equal(dataset->max_bin_by_feature_.begin(), dataset->max_bin_by_feature_.end(),
            config_.max_bin_by_feature.begin())) {
      Log::Fatal("Dataset max_bin_by_feature does not match with config");
    }

    int label_idx = -1;
    if (Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx)) {
      if (dataset->label_idx_ != label_idx) {
769
        Log::Fatal("Dataset label_idx %d != config %d", dataset->label_idx_, label_idx);
770
771
772
773
774
      }
    } else {
      Log::Info("Recommend use integer for label index when loading data from binary for sanity check.");
    }
  }
Guolin Ke's avatar
Guolin Ke committed
775
776
777
}

std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
778
779
                                                             int rank, int num_machines, int* num_global_data,
                                                             std::vector<data_size_t>* used_data_indices) {
780
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
781
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
782
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
783
784
785
786
787
788
789
790
791
    // read all lines
    *num_global_data = text_reader.ReadAllLines();
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();

    if (query_boundaries == nullptr) {
      // if not contain query data, minimal sample unit is one record
      *num_global_data = text_reader.ReadAndFilterLines([this, rank, num_machines](data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
792
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
793
794
795
796
797
798
799
800
801
802
803
804
805
806
          return true;
        } else {
          return false;
        }
      }, used_data_indices);
    } else {
      // if contain query data, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.ReadAndFilterLines(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
807
808
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
809
810
811
812
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
813
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
814
815
816
817
818
819
820
821
822
823
824
825
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices);
    }
  }
  return std::move(text_reader.Lines());
}

std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
Guolin Ke's avatar
Guolin Ke committed
826
  int sample_cnt = config_.bin_construct_sample_cnt;
827
828
  if (static_cast<size_t>(sample_cnt) > data.size()) {
    sample_cnt = static_cast<int>(data.size());
829
  }
830
  auto sample_indices = random_.Sample(static_cast<int>(data.size()), sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
831
  std::vector<std::string> out(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
832
833
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    const size_t idx = sample_indices[i];
Guolin Ke's avatar
Guolin Ke committed
834
    out[i] = data[idx];
Guolin Ke's avatar
Guolin Ke committed
835
836
837
838
  }
  return out;
}

839
840
841
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata,
                                                               int rank, int num_machines, int* num_global_data,
                                                               std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
842
  const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
843
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
844
  std::vector<std::string> out_data;
Guolin Ke's avatar
Guolin Ke committed
845
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
846
    *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(&random_, sample_cnt, &out_data));
Guolin Ke's avatar
Guolin Ke committed
847
848
849
850
851
852
853
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      *num_global_data = text_reader.SampleAndFilterFromFile([this, rank, num_machines]
      (data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
854
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
855
856
857
858
          return true;
        } else {
          return false;
        }
Guolin Ke's avatar
Guolin Ke committed
859
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
860
861
862
863
864
865
866
867
868
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.SampleAndFilterFromFile(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
869
870
          Log::Fatal("Query id exceeds the range of the query file, "
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
871
872
873
874
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
875
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
876
877
878
879
880
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
Guolin Ke's avatar
Guolin Ke committed
881
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
882
883
884
885
886
    }
  }
  return out_data;
}

887
888
889
void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
                                                    const std::vector<std::string>& sample_data,
                                                    const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
890
  std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
891
  std::vector<std::vector<int>> sample_indices;
Guolin Ke's avatar
Guolin Ke committed
892
893
  std::vector<std::pair<int, double>> oneline_features;
  double label;
Guolin Ke's avatar
Guolin Ke committed
894
  for (int i = 0; i < static_cast<int>(sample_data.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
895
896
897
898
    oneline_features.clear();
    // parse features
    parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
    for (std::pair<int, double>& inner_data : oneline_features) {
899
      if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
900
901
        sample_values.resize(inner_data.first + 1);
        sample_indices.resize(inner_data.first + 1);
902
      }
Guolin Ke's avatar
Guolin Ke committed
903
      if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
904
905
        sample_values[inner_data.first].emplace_back(inner_data.second);
        sample_indices[inner_data.first].emplace_back(i);
Guolin Ke's avatar
Guolin Ke committed
906
907
908
909
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
910
  dataset->feature_groups_.clear();
911
912
913
914
915
  dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
  if (num_machines > 1) {
    dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_);
  }
  if (!feature_names_.empty()) {
916
    CHECK_EQ(dataset->num_total_features_, static_cast<int>(feature_names_.size()));
917
  }
Guolin Ke's avatar
Guolin Ke committed
918

Belinda Trotta's avatar
Belinda Trotta committed
919
  if (!config_.max_bin_by_feature.empty()) {
920
921
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
922
923
  }

924
925
  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
926
927
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path,
                                                                                    dataset->num_total_features_,
928
929
                                                                                    categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
930
931
932
933
934
935
  // check the range of label_idx, weight_idx and group_idx
  CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
  CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
  CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);

  // fill feature_names_ if not header
Guolin Ke's avatar
Guolin Ke committed
936
  if (feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
937
938
939
940
941
942
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
943
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
944
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
945
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
946
    static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
Guolin Ke's avatar
Guolin Ke committed
947
948
949
  // start find bins
  if (num_machines == 1) {
    // if only one machine, find bin locally
950
    OMP_INIT_EX();
951
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
952
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
953
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
954
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
955
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
956
957
        continue;
      }
958
959
960
961
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
Guolin Ke's avatar
Guolin Ke committed
962
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
963
964
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
965
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
966
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
967
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
968
969
      } else {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
970
                                sample_data.size(), config_.max_bin_by_feature[i],
971
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
972
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
973
      }
974
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
975
    }
976
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
977
978
  } else {
    // start and len will store the process feature indices for different machines
979
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
Guolin Ke's avatar
Guolin Ke committed
980
981
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
982
    int step = (dataset->num_total_features_ + num_machines - 1) / num_machines;
Guolin Ke's avatar
Guolin Ke committed
983
984
985
986
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
987
      len[i] = std::min(step, dataset->num_total_features_ - start[i]);
Guolin Ke's avatar
Guolin Ke committed
988
989
      start[i + 1] = start[i] + len[i];
    }
990
    len[num_machines - 1] = dataset->num_total_features_ - start[num_machines - 1];
991
    OMP_INIT_EX();
992
    #pragma omp parallel for schedule(guided)
993
    for (int i = 0; i < len[rank]; ++i) {
994
      OMP_LOOP_EX_BEGIN();
995
996
997
998
999
1000
1001
1002
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
Nikita Titov's avatar
Nikita Titov committed
1003
      if (static_cast<int>(sample_values.size()) <= start[rank] + i) {
1004
1005
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
1006
      if (config_.max_bin_by_feature.empty()) {
1007
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
1008
                                static_cast<int>(sample_values[start[rank] + i].size()),
1009
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
1010
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
1011
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1012
      } else {
1013
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
1014
                                static_cast<int>(sample_values[start[rank] + i].size()),
1015
                                sample_data.size(), config_.max_bin_by_feature[i],
1016
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type,
1017
                                config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1018
      }
1019
      OMP_LOOP_EX_END();
1020
    }
1021
    OMP_THROW_EX();
1022
    comm_size_t self_buf_size = 0;
Guolin Ke's avatar
Guolin Ke committed
1023
    for (int i = 0; i < len[rank]; ++i) {
1024
1025
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
Guolin Ke's avatar
Guolin Ke committed
1026
      }
1027
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
1028
    }
1029
1030
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1031
    for (int i = 0; i < len[rank]; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1032
1033
1034
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
1035
1036
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
1037
1038
      // free
      bin_mappers[i].reset(nullptr);
Guolin Ke's avatar
Guolin Ke committed
1039
    }
1040
1041
1042
1043
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
Guolin Ke's avatar
Guolin Ke committed
1044
    }
1045
1046
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
Guolin Ke's avatar
Guolin Ke committed
1047
    // gather global feature bin mappers
1048
1049
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1050
    // restore features bins from buffer
1051
    for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1052
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1053
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1054
1055
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
1056
      bin_mappers[i].reset(new BinMapper());
1057
1058
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
Guolin Ke's avatar
Guolin Ke committed
1059
1060
    }
  }
1061
  dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Guolin Ke's avatar
Guolin Ke committed
1062
                     Common::Vector2Ptr<double>(&sample_values).data(),
1063
                     Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
Guolin Ke's avatar
Guolin Ke committed
1064
1065
1066
}

/*! \brief Extract local features from memory */
Guolin Ke's avatar
Guolin Ke committed
1067
void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
1068
1069
  std::vector<std::pair<int, double>> oneline_features;
  double tmp_label = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
1070
  auto& ref_text_data = *text_data;
Guolin Ke's avatar
Guolin Ke committed
1071
  if (predict_fun_ == nullptr) {
1072
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1073
    // if doesn't need to prediction with initial model
1074
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1075
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1076
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1077
1078
1079
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1080
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1081
      // set label
1082
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1083
      // free processed line:
Guolin Ke's avatar
Guolin Ke committed
1084
      ref_text_data[i].clear();
Guolin Ke's avatar
Guolin Ke committed
1085
1086
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
1087
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1088
1089
      // push data
      for (auto& inner_data : oneline_features) {
1090
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1091
1092
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1093
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1094
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1095
1096
1097
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1098
1099
        } else {
          if (inner_data.first == weight_idx_) {
1100
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1101
1102
1103
1104
1105
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1106
      dataset->FinishOneRow(tid, i, is_feature_added);
1107
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1108
    }
1109
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1110
  } else {
1111
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1112
    // if need to prediction with initial model
1113
    std::vector<double> init_score(dataset->num_data_ * num_class_);
1114
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1115
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1116
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1117
1118
1119
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1120
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1121
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1122
1123
      std::vector<double> oneline_init_score(num_class_);
      predict_fun_(oneline_features, oneline_init_score.data());
1124
      for (int k = 0; k < num_class_; ++k) {
1125
        init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1126
1127
      }
      // set label
1128
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1129
1130
1131
1132
1133
      // free processed line:
      text_data[i].clear();
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
Guolin Ke's avatar
Guolin Ke committed
1134
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1135
      for (auto& inner_data : oneline_features) {
1136
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1137
1138
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1139
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1140
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1141
1142
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
1143
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1144
1145
        } else {
          if (inner_data.first == weight_idx_) {
1146
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1147
1148
1149
1150
1151
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1152
      dataset->FinishOneRow(tid, i, is_feature_added);
1153
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1154
    }
1155
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1156
    // metadata_ will manage space of init_score
1157
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1158
  }
Guolin Ke's avatar
Guolin Ke committed
1159
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1160
  // text data can be free after loaded feature values
Guolin Ke's avatar
Guolin Ke committed
1161
  text_data->clear();
Guolin Ke's avatar
Guolin Ke committed
1162
1163
1164
}

/*! \brief Extract local features from file */
1165
1166
void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser,
                                            const std::vector<data_size_t>& used_data_indices, Dataset* dataset) {
1167
  std::vector<double> init_score;
Guolin Ke's avatar
Guolin Ke committed
1168
  if (predict_fun_ != nullptr) {
1169
    init_score = std::vector<double>(dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1170
1171
1172
1173
1174
1175
  }
  std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
    [this, &init_score, &parser, &dataset]
  (data_size_t start_idx, const std::vector<std::string>& lines) {
    std::vector<std::pair<int, double>> oneline_features;
    double tmp_label = 0.0f;
1176
    OMP_INIT_EX();
1177
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1178
    for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
1179
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1180
1181
1182
1183
1184
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1185
      if (!init_score.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1186
1187
        std::vector<double> oneline_init_score(num_class_);
        predict_fun_(oneline_features, oneline_init_score.data());
1188
        for (int k = 0; k < num_class_; ++k) {
1189
          init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1190
1191
1192
        }
      }
      // set label
1193
      dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1194
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1195
1196
      // push data
      for (auto& inner_data : oneline_features) {
1197
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1198
1199
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1200
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1201
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1202
1203
1204
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1205
1206
        } else {
          if (inner_data.first == weight_idx_) {
1207
            dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1208
1209
1210
1211
1212
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1213
      dataset->FinishOneRow(tid, i, is_feature_added);
1214
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1215
    }
1216
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1217
  };
1218
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
1219
  if (!used_data_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1220
1221
1222
1223
1224
1225
1226
1227
    // only need part of data
    text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
  } else {
    // need full data
    text_reader.ReadAllAndProcessParallel(process_fun);
  }

  // metadata_ will manage space of init_score
Guolin Ke's avatar
Guolin Ke committed
1228
  if (!init_score.empty()) {
1229
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1230
  }
Guolin Ke's avatar
Guolin Ke committed
1231
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1232
1233
1234
}

/*! \brief Check can load from binary file */
1235
std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
1236
1237
1238
  std::string bin_filename(filename);
  bin_filename.append(".bin");

1239
  auto reader = VirtualFileReader::Make(bin_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
1240

1241
  if (!reader->Init()) {
1242
    bin_filename = std::string(filename);
1243
1244
    reader = VirtualFileReader::Make(bin_filename.c_str());
    if (!reader->Init()) {
1245
      Log::Fatal("Cannot open data file %s", bin_filename.c_str());
1246
    }
1247
  }
1248
1249
1250
1251
1252

  size_t buffer_size = 256;
  auto buffer = std::vector<char>(buffer_size);
  // read size of token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
1253
  size_t read_cnt = reader->Read(buffer.data(), size_of_token);
1254
1255
  if (read_cnt == size_of_token
      && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
1256
    return bin_filename;
Guolin Ke's avatar
Guolin Ke committed
1257
  } else {
1258
    return std::string();
Guolin Ke's avatar
Guolin Ke committed
1259
1260
1261
  }
}

1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274


std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
                                                              const std::unordered_set<int>& categorical_features) {
  std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
  if (forced_bins_path != "") {
    std::ifstream forced_bins_stream(forced_bins_path.c_str());
    if (forced_bins_stream.fail()) {
      Log::Warning("Could not open %s. Will ignore.", forced_bins_path.c_str());
    } else {
      std::stringstream buffer;
      buffer << forced_bins_stream.rdbuf();
      std::string err;
Guolin Ke's avatar
Guolin Ke committed
1275
      Json forced_bins_json = Json::parse(buffer.str(), &err);
1276
1277
1278
1279
      CHECK(forced_bins_json.is_array());
      std::vector<Json> forced_bins_arr = forced_bins_json.array_items();
      for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
        int feature_num = forced_bins_arr[i]["feature"].int_value();
Nikita Titov's avatar
Nikita Titov committed
1280
        CHECK_LT(feature_num, num_total_features);
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
        if (categorical_features.count(feature_num)) {
          Log::Warning("Feature %d is categorical. Will ignore forced bins for this  feature.", feature_num);
        } else {
          std::vector<Json> bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items();
          for (size_t j = 0; j < bounds_arr.size(); ++j) {
            forced_bins[feature_num].push_back(bounds_arr[j].number_value());
          }
        }
      }
      // remove duplicates
      for (int i = 0; i < num_total_features; ++i) {
        auto new_end = std::unique(forced_bins[i].begin(), forced_bins[i].end());
        forced_bins[i].erase(new_end, forced_bins[i].end());
      }
    }
  }
  return forced_bins;
}

1300
}  // namespace LightGBM