dataset_loader.cpp 55.2 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/dataset_loader.h>

Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/network.h>
8
#include <LightGBM/utils/array_args.h>
9
#include <LightGBM/utils/json11.h>
10
11
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
12

13
14
#include <fstream>

Guolin Ke's avatar
Guolin Ke committed
15
16
namespace LightGBM {

17
18
using json11::Json;

Guolin Ke's avatar
Guolin Ke committed
19
20
DatasetLoader::DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename)
  :config_(io_config), random_(config_.data_random_seed), predict_fun_(predict_fun), num_class_(num_class) {
  // Column-role defaults: no weight column, no group/query column, and the
  // label in column 0. SetHeader() may overwrite any of these from config_.
  group_idx_ = NO_SPECIFIC;
  weight_idx_ = NO_SPECIFIC;
  label_idx_ = 0;
  // Resolve column roles (and feature names, when a header exists) for `filename`.
  SetHeader(filename);
}

// Nothing to release explicitly; members clean up after themselves.
DatasetLoader::~DatasetLoader() = default;

Guolin Ke's avatar
Guolin Ke committed
30
void DatasetLoader::SetHeader(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
31
  std::unordered_map<std::string, int> name2idx;
Guolin Ke's avatar
Guolin Ke committed
32
33
  std::string name_prefix("name:");
  if (filename != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
34
    TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
35

Guolin Ke's avatar
Guolin Ke committed
36
    // get column names
Guolin Ke's avatar
Guolin Ke committed
37
    if (config_.header) {
Guolin Ke's avatar
Guolin Ke committed
38
      std::string first_line = text_reader.first_line();
39
      feature_names_ = Common::Split(first_line.c_str(), "\t,");
Guolin Ke's avatar
Guolin Ke committed
40
41
    }

Guolin Ke's avatar
Guolin Ke committed
42
    // load label idx first
Guolin Ke's avatar
Guolin Ke committed
43
44
45
    if (config_.label_column.size() > 0) {
      if (Common::StartsWith(config_.label_column, name_prefix)) {
        std::string name = config_.label_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
46
47
48
49
50
51
52
53
54
55
        label_idx_ = -1;
        for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
          if (name == feature_names_[i]) {
            label_idx_ = i;
            break;
          }
        }
        if (label_idx_ >= 0) {
          Log::Info("Using column %s as label", name.c_str());
        } else {
56
57
          Log::Fatal("Could not find label column %s in data file \n"
                     "or data file doesn't contain header", name.c_str());
Guolin Ke's avatar
Guolin Ke committed
58
        }
Guolin Ke's avatar
Guolin Ke committed
59
      } else {
Guolin Ke's avatar
Guolin Ke committed
60
        if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
61
62
63
          Log::Fatal("label_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
64
65
        }
        Log::Info("Using column number %d as label", label_idx_);
Guolin Ke's avatar
Guolin Ke committed
66
67
      }
    }
Guolin Ke's avatar
Guolin Ke committed
68

Guolin Ke's avatar
Guolin Ke committed
69
    if (!feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
70
71
72
73
      // erase label column name
      feature_names_.erase(feature_names_.begin() + label_idx_);
      for (size_t i = 0; i < feature_names_.size(); ++i) {
        name2idx[feature_names_[i]] = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
74
      }
Guolin Ke's avatar
Guolin Ke committed
75
76
77
    }

    // load ignore columns
Guolin Ke's avatar
Guolin Ke committed
78
79
80
    if (config_.ignore_column.size() > 0) {
      if (Common::StartsWith(config_.ignore_column, name_prefix)) {
        std::string names = config_.ignore_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
81
82
83
84
85
86
87
88
89
        for (auto name : Common::Split(names.c_str(), ',')) {
          if (name2idx.count(name) > 0) {
            int tmp = name2idx[name];
            ignore_features_.emplace(tmp);
          } else {
            Log::Fatal("Could not find ignore column %s in data file", name.c_str());
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
90
        for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
Guolin Ke's avatar
Guolin Ke committed
91
92
          int tmp = 0;
          if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
93
94
95
            Log::Fatal("ignore_column is not a number,\n"
                       "if you want to use a column name,\n"
                       "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
96
97
          }
          ignore_features_.emplace(tmp);
Guolin Ke's avatar
Guolin Ke committed
98
99
100
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
101
    // load weight idx
Guolin Ke's avatar
Guolin Ke committed
102
103
104
    if (config_.weight_column.size() > 0) {
      if (Common::StartsWith(config_.weight_column, name_prefix)) {
        std::string name = config_.weight_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
105
106
107
108
109
110
        if (name2idx.count(name) > 0) {
          weight_idx_ = name2idx[name];
          Log::Info("Using column %s as weight", name.c_str());
        } else {
          Log::Fatal("Could not find weight column %s in data file", name.c_str());
        }
Guolin Ke's avatar
Guolin Ke committed
111
      } else {
Guolin Ke's avatar
Guolin Ke committed
112
        if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
113
114
115
          Log::Fatal("weight_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
116
117
        }
        Log::Info("Using column number %d as weight", weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
118
      }
Guolin Ke's avatar
Guolin Ke committed
119
      ignore_features_.emplace(weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
120
    }
Guolin Ke's avatar
Guolin Ke committed
121
    // load group idx
Guolin Ke's avatar
Guolin Ke committed
122
123
124
    if (config_.group_column.size() > 0) {
      if (Common::StartsWith(config_.group_column, name_prefix)) {
        std::string name = config_.group_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
125
126
127
128
129
130
131
        if (name2idx.count(name) > 0) {
          group_idx_ = name2idx[name];
          Log::Info("Using column %s as group/query id", name.c_str());
        } else {
          Log::Fatal("Could not find group/query column %s in data file", name.c_str());
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
132
        if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
133
134
135
          Log::Fatal("group_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
136
137
138
139
        }
        Log::Info("Using column number %d as group/query id", group_idx_);
      }
      ignore_features_.emplace(group_idx_);
Guolin Ke's avatar
Guolin Ke committed
140
141
    }
  }
Guolin Ke's avatar
Guolin Ke committed
142
143
144
  if (config_.categorical_feature.size() > 0) {
    if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
      std::string names = config_.categorical_feature.substr(name_prefix.size());
145
146
147
148
149
      for (auto name : Common::Split(names.c_str(), ',')) {
        if (name2idx.count(name) > 0) {
          int tmp = name2idx[name];
          categorical_features_.emplace(tmp);
        } else {
Guolin Ke's avatar
Guolin Ke committed
150
          Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
151
152
153
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
154
      for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
155
156
        int tmp = 0;
        if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Guolin Ke's avatar
Guolin Ke committed
157
          Log::Fatal("categorical_feature is not a number,\n"
158
159
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
160
161
162
163
164
        }
        categorical_features_.emplace(tmp);
      }
    }
  }
Guolin Ke's avatar
Guolin Ke committed
165
166
}

167
168
169
170
171
172
173
174
175
176
// Emit a warning when the bin-construction sample is small both relative to
// the data (under 20% of num_data) and in absolute size (under 100000 rows).
void CheckSampleSize(size_t sample_cnt, size_t num_data) {
  const double sampled_fraction = static_cast<double>(sample_cnt) / num_data;
  const bool is_small_fraction = sampled_fraction < 0.2f;
  const bool is_small_count = sample_cnt < 100000;
  if (!is_small_fraction || !is_small_count) {
    return;
  }
  Log::Warning("Using too small ``bin_construct_sample_cnt`` may encounter unexpected errors and poor accuracy.");
}

177
/*!
 * \brief Load a training dataset from a text file, or from its binary cache
 *        when CheckCanLoadFromBin() reports one.
 * \param filename Data file path.
 * \param rank Rank of this machine in distributed training.
 * \param num_machines Number of machines; > 1 enables local partitioning.
 * \return Newly allocated Dataset; ownership passes to the caller.
 */
Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_machines) {
  // don't support query id in data file when training in parallel.
  // group_idx_ is a 0-based feature index, so index 0 is a valid group column;
  // "not configured" is represented by NO_SPECIFIC (set in the constructor).
  if (num_machines > 1 && !config_.pre_partition) {
    if (group_idx_ != NO_SPECIFIC) {
      Log::Fatal("Using a query id without pre-partitioning the data file is not supported for parallel training.\n"
                 "Please use an additional query file or pre-partition the data");
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto bin_filename = CheckCanLoadFromBin(filename);
  bool is_load_from_binary = false;
  if (bin_filename.size() == 0) {
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data to memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // sample data
      auto sample_data = SampleTextDataFromMemory(text_data);
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      // sample data from file
      auto sample_data = SampleTextDataFromFile(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      if (used_data_indices.size() > 0) {
        dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
      } else {
        dataset->num_data_ = num_global_data;
      }
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      Log::Debug("Making second pass...");
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    is_load_from_binary = true;
    Log::Info("Load from binary file %s", bin_filename.c_str());
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
  }
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  // need to check training data
  CheckDataset(dataset.get(), is_load_from_binary);

  return dataset.release();
}



247
/*!
 * \brief Load a validation dataset whose bin layout mirrors `train_data`.
 *
 * Uses the binary cache when CheckCanLoadFromBin() finds one. Otherwise it
 * parses the text file, either fully in memory or in a streaming pass when
 * config_.two_round is set. Always loads as a single machine (rank 0 of 1).
 * No CheckDataset() pass is run for validation data.
 * \return Newly allocated Dataset; ownership passes to the caller.
 */
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
  data_size_t total_rows = 0;
  std::vector<data_size_t> kept_rows;
  auto valid_set = std::unique_ptr<Dataset>(new Dataset());
  const auto cached_bin = CheckCanLoadFromBin(filename);
  if (!cached_bin.empty()) {
    // a binary cache is available: load it directly
    valid_set.reset(LoadFromBinFile(filename, cached_bin.c_str(), 0, 1, &total_rows, &kept_rows));
  } else {
    std::unique_ptr<Parser> text_parser(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (!text_parser) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    valid_set->data_filename_ = filename;
    valid_set->label_idx_ = label_idx_;
    valid_set->metadata_.Init(filename);
    if (config_.two_round) {
      // only count rows up front; features are streamed from the file afterwards
      TextReader<data_size_t> line_counter(filename, config_.header);
      valid_set->num_data_ = static_cast<data_size_t>(line_counter.CountLine());
      total_rows = valid_set->num_data_;
      valid_set->metadata_.Init(valid_set->num_data_, weight_idx_, group_idx_);
      valid_set->CreateValid(train_data);
      ExtractFeaturesFromFile(filename, text_parser.get(), kept_rows, valid_set.get());
    } else {
      // whole file held in memory for a single extraction pass
      auto lines = LoadTextDataToMemory(filename, valid_set->metadata_, 0, 1, &total_rows, &kept_rows);
      valid_set->num_data_ = static_cast<data_size_t>(lines.size());
      valid_set->metadata_.Init(valid_set->num_data_, weight_idx_, group_idx_);
      valid_set->CreateValid(train_data);
      ExtractFeaturesFromMemory(&lines, text_parser.get(), valid_set.get());
      lines.clear();
    }
  }
  // metadata is still checked/partitioned even though the data itself is not validated
  valid_set->metadata_.CheckOrPartition(total_rows, kept_rows);
  return valid_set.release();
}

291
292
293
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
                                        int rank, int num_machines, int* num_global_data,
                                        std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
294
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
295
  auto reader = VirtualFileReader::Make(bin_filename);
Guolin Ke's avatar
Guolin Ke committed
296
  dataset->data_filename_ = data_filename;
297
  if (!reader->Init()) {
Guolin Ke's avatar
Guolin Ke committed
298
299
300
301
302
    Log::Fatal("Could not read binary data from %s", bin_filename);
  }

  // buffer to read binary file
  size_t buffer_size = 16 * 1024 * 1024;
Guolin Ke's avatar
Guolin Ke committed
303
  auto buffer = std::vector<char>(buffer_size);
304

305
306
  // check token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
307
308
309
310
  size_t read_cnt = reader->Read(
      buffer.data(),
      VirtualFileWriter::AlignedSize(sizeof(char) * size_of_token));
  if (read_cnt < sizeof(char) * size_of_token) {
311
312
313
    Log::Fatal("Binary file error: token has the wrong size");
  }
  if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
314
    Log::Fatal("Input file is not LightGBM binary file");
315
  }
Guolin Ke's avatar
Guolin Ke committed
316
317

  // read size of header
318
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
319

320
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
321
322
323
    Log::Fatal("Binary file error: header has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
324
  size_t size_of_head = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
325
326
327
328

  // re-allocmate space if not enough
  if (size_of_head > buffer_size) {
    buffer_size = size_of_head;
Guolin Ke's avatar
Guolin Ke committed
329
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
330
331
  }
  // read header
332
  read_cnt = reader->Read(buffer.data(), size_of_head);
Guolin Ke's avatar
Guolin Ke committed
333
334
335
336
337

  if (read_cnt != size_of_head) {
    Log::Fatal("Binary file error: header is incorrect");
  }
  // get header
Guolin Ke's avatar
Guolin Ke committed
338
  const char* mem_ptr = buffer.data();
Guolin Ke's avatar
Guolin Ke committed
339
  dataset->num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
340
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_data_));
Guolin Ke's avatar
Guolin Ke committed
341
  dataset->num_features_ = *(reinterpret_cast<const int*>(mem_ptr));
342
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_features_));
Guolin Ke's avatar
Guolin Ke committed
343
  dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
344
345
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(dataset->num_total_features_));
Guolin Ke's avatar
Guolin Ke committed
346
  dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
347
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->label_idx_));
348
  dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
349
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->max_bin_));
350
  dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
351
352
  mem_ptr += VirtualFileWriter::AlignedSize(
      sizeof(dataset->bin_construct_sample_cnt_));
353
  dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
354
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->min_data_in_bin_));
355
  dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
356
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->use_missing_));
357
  dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
358
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->zero_as_missing_));
Guolin Ke's avatar
Guolin Ke committed
359
360
  const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
  dataset->used_feature_map_.clear();
Guolin Ke's avatar
Guolin Ke committed
361
  for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
362
363
    dataset->used_feature_map_.push_back(tmp_feature_map[i]);
  }
364
365
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int) *
                                            dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
366
367
  // num_groups
  dataset->num_groups_ = *(reinterpret_cast<const int*>(mem_ptr));
368
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
369
370
371
372
373
374
  // real_feature_idx_
  const int* tmp_ptr_real_feature_idx_ = reinterpret_cast<const int*>(mem_ptr);
  dataset->real_feature_idx_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->real_feature_idx_.push_back(tmp_ptr_real_feature_idx_[i]);
  }
375
376
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
377
378
379
380
381
382
  // feature2group
  const int* tmp_ptr_feature2group = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2group_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2group_.push_back(tmp_ptr_feature2group[i]);
  }
383
384
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
385
386
387
388
389
390
  // feature2subfeature
  const int* tmp_ptr_feature2subfeature = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2subfeature_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2subfeature_.push_back(tmp_ptr_feature2subfeature[i]);
  }
391
392
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
393
394
395
396
397
398
399
400
401
402
403
  // group_bin_boundaries
  const uint64_t* tmp_ptr_group_bin_boundaries = reinterpret_cast<const uint64_t*>(mem_ptr);
  dataset->group_bin_boundaries_.clear();
  for (int i = 0; i < dataset->num_groups_ + 1; ++i) {
    dataset->group_bin_boundaries_.push_back(tmp_ptr_group_bin_boundaries[i]);
  }
  mem_ptr += sizeof(uint64_t) * (dataset->num_groups_ + 1);

  // group_feature_start_
  const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_start_.clear();
404
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
405
406
    dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
  }
407
408
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
409
410
411
412
413
414
415

  // group_feature_cnt_
  const int* tmp_ptr_group_feature_cnt = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_cnt_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_cnt_.push_back(tmp_ptr_group_feature_cnt[i]);
  }
416
417
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
418

Belinda Trotta's avatar
Belinda Trotta committed
419
  if (!config_.max_bin_by_feature.empty()) {
420
421
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
422
423
424
425
426
427
428
429
430
    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
  } else {
    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
    dataset->max_bin_by_feature_.clear();
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
    }
  }
431
432
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int32_t) *
                                            (dataset->num_total_features_));
Belinda Trotta's avatar
Belinda Trotta committed
433
434
435
436
  if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
    dataset->max_bin_by_feature_.clear();
  }

Guolin Ke's avatar
Guolin Ke committed
437
  // get feature names
Guolin Ke's avatar
Guolin Ke committed
438
  dataset->feature_names_.clear();
Guolin Ke's avatar
Guolin Ke committed
439
440
441
  // write feature names
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int str_len = *(reinterpret_cast<const int*>(mem_ptr));
442
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
Guolin Ke's avatar
Guolin Ke committed
443
    std::stringstream str_buf;
444
    auto tmp_arr = reinterpret_cast<const char*>(mem_ptr);
Guolin Ke's avatar
Guolin Ke committed
445
    for (int j = 0; j < str_len; ++j) {
446
      char tmp_char = tmp_arr[j];
Guolin Ke's avatar
Guolin Ke committed
447
448
      str_buf << tmp_char;
    }
449
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(char) * str_len);
Guolin Ke's avatar
Guolin Ke committed
450
    dataset->feature_names_.emplace_back(str_buf.str());
Guolin Ke's avatar
Guolin Ke committed
451
  }
452
453
454
455
  // get forced_bin_bounds_
  dataset->forced_bin_bounds_ = std::vector<std::vector<double>>(dataset->num_total_features_, std::vector<double>());
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int num_bounds = *(reinterpret_cast<const int*>(mem_ptr));
456
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
457
    dataset->forced_bin_bounds_[i] = std::vector<double>();
458
459
    const double* tmp_ptr_forced_bounds =
        reinterpret_cast<const double*>(mem_ptr);
460
461
462
463
464
465
    for (int j = 0; j < num_bounds; ++j) {
      double bound = tmp_ptr_forced_bounds[j];
      dataset->forced_bin_bounds_[i].push_back(bound);
    }
    mem_ptr += num_bounds * sizeof(double);
  }
Guolin Ke's avatar
Guolin Ke committed
466
467

  // read size of meta data
468
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
469

470
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
471
472
473
    Log::Fatal("Binary file error: meta data has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
474
  size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
475
476
477
478

  // re-allocate space if not enough
  if (size_of_metadata > buffer_size) {
    buffer_size = size_of_metadata;
Guolin Ke's avatar
Guolin Ke committed
479
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
480
481
  }
  //  read meta data
482
  read_cnt = reader->Read(buffer.data(), size_of_metadata);
Guolin Ke's avatar
Guolin Ke committed
483
484
485
486
487

  if (read_cnt != size_of_metadata) {
    Log::Fatal("Binary file error: meta data is incorrect");
  }
  // load meta data
Guolin Ke's avatar
Guolin Ke committed
488
  dataset->metadata_.LoadFromMemory(buffer.data());
Guolin Ke's avatar
Guolin Ke committed
489

490
491
  *num_global_data = dataset->num_data_;
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
492
  // sample local used data if need to partition
Guolin Ke's avatar
Guolin Ke committed
493
  if (num_machines > 1 && !config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
494
495
496
497
    const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
498
        if (random_.NextShort(0, num_machines) == rank) {
499
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
500
501
502
503
504
505
506
507
508
        }
      }
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = dataset->metadata_.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (qid >= num_queries) {
509
510
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
511
512
513
514
        }
        if (i >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
515
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
516
517
518
519
520
            is_query_used = true;
          }
          ++qid;
        }
        if (is_query_used) {
521
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
522
523
524
        }
      }
    }
525
    dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
Guolin Ke's avatar
Guolin Ke committed
526
  }
527
  dataset->metadata_.PartitionLabel(*used_data_indices);
Guolin Ke's avatar
Guolin Ke committed
528
  // read feature data
Guolin Ke's avatar
Guolin Ke committed
529
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
530
    // read feature size
531
532
    read_cnt = reader->Read(buffer.data(), sizeof(size_t));
    if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
533
534
      Log::Fatal("Binary file error: feature %d has the wrong size", i);
    }
Guolin Ke's avatar
Guolin Ke committed
535
    size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
536
537
538
    // re-allocate space if not enough
    if (size_of_feature > buffer_size) {
      buffer_size = size_of_feature;
Guolin Ke's avatar
Guolin Ke committed
539
      buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
540
541
    }

542
    read_cnt = reader->Read(buffer.data(), size_of_feature);
Guolin Ke's avatar
Guolin Ke committed
543
544
545
546

    if (read_cnt != size_of_feature) {
      Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
    }
Guolin Ke's avatar
Guolin Ke committed
547
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
548
549
      new FeatureGroup(buffer.data(),
                       *num_global_data,
550
                       *used_data_indices, i)));
Guolin Ke's avatar
Guolin Ke committed
551
  }
Guolin Ke's avatar
Guolin Ke committed
552
  dataset->feature_groups_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
553
  dataset->is_finish_load_ = true;
Guolin Ke's avatar
Guolin Ke committed
554
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
555
556
}

557

558
559
560
/*!
 * \brief Construct a Dataset whose bin mappers are computed from pre-sampled column data.
 * \param sample_values Per-column arrays of sampled values; column j holds num_per_col[j] entries
 * \param sample_indices Per-column arrays of sampled row indices (forwarded to Dataset::Construct)
 * \param num_col Number of columns available on this machine
 * \param num_per_col Number of sampled entries in each column
 * \param total_sample_size Total number of sampled rows
 * \param num_data Number of rows of the full dataset (used to rescale min_data_in_leaf)
 * \return A newly allocated Dataset; the caller takes ownership
 *
 * With one machine every column is binned locally. With several machines, each rank
 * bins the contiguous feature slice [start[rank], start[rank] + len[rank]), serializes
 * its mappers, and Network::Allgather distributes them so every rank holds all mappers.
 */
Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
                                                int** sample_indices, int num_col, const int* num_per_col,
                                                size_t total_sample_size, data_size_t num_data) {
  CheckSampleSize(total_sample_size, static_cast<size_t>(num_data));
  int num_total_features = num_col;
  if (Network::num_machines() > 1) {
    num_total_features = Network::GlobalSyncUpByMax(num_total_features);
  }
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_total_features);
  // fill feature_names_ if not header
  if (feature_names_.empty()) {
    for (int i = 0; i < num_col; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
  if (!config_.max_bin_by_feature.empty()) {
    CHECK_EQ(static_cast<size_t>(num_col), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
  }

  // get forced split; forced_bin_bounds is indexed by global column index
  std::string forced_bins_path = config_.forcedbins_filename;
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);

  // min_data_in_leaf rescaled from the full dataset size to the sample size
  const data_size_t filter_cnt = static_cast<data_size_t>(
    static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
  if (Network::num_machines() == 1) {
    // if only one machine, find bin locally
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
        bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
        if (!feat_is_unconstrained) {
            Log::Fatal("The output cannot be monotone with respect to categorical features");
        }
      }
      bin_mappers[i].reset(new BinMapper());
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                                config_.max_bin, config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter,
                                bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[i]);
      } else {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                                config_.max_bin_by_feature[i], config_.min_data_in_bin,
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
                                config_.zero_as_missing, forced_bin_bounds[i]);
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    // if have multi-machines, need to find bin distributed
    // different machines will find bin for different features
    int num_machines = Network::num_machines();
    int rank = Network::rank();
    // start and len will store the process feature indices for different machines
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
    int step = (num_total_features + num_machines - 1) / num_machines;
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
      len[i] = std::min(step, num_total_features - start[i]);
      start[i + 1] = start[i] + len[i];
    }
    len[num_machines - 1] = num_total_features - start[num_machines - 1];
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < len[rank]; ++i) {
      OMP_LOOP_EX_BEGIN();
      // bin_mappers[i] is a local slot; feature_idx is the global column it describes.
      // Every per-feature lookup below must use feature_idx, not i.
      const int feature_idx = start[rank] + i;
      if (ignore_features_.count(feature_idx) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(feature_idx)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
      // feature indices beyond this machine's num_col (possible after GlobalSyncUpByMax)
      // keep a default-constructed mapper
      if (num_col <= feature_idx) {
        continue;
      }
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[feature_idx], num_per_col[feature_idx],
                                total_sample_size, config_.max_bin, config_.min_data_in_bin,
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[feature_idx]);
      } else {
        bin_mappers[i]->FindBin(sample_values[feature_idx], num_per_col[feature_idx],
                                total_sample_size, config_.max_bin_by_feature[feature_idx],
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
                                config_.zero_as_missing, forced_bin_bounds[feature_idx]);
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
    comm_size_t self_buf_size = 0;
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
    }
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
      // free
      bin_mappers[i].reset(nullptr);
    }
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
    }
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
    // gather global feature bin mappers
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
    // restore features bins from buffer
    for (int i = 0; i < num_total_features; ++i) {
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      bin_mappers[i].reset(new BinMapper());
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
  dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
  dataset->set_feature_names(feature_names_);
  return dataset.release();
}
Guolin Ke's avatar
Guolin Ke committed
710
711
712
713


// ---- private functions ----

714
void DatasetLoader::CheckDataset(const Dataset* dataset, bool is_load_from_binary) {
Guolin Ke's avatar
Guolin Ke committed
715
  if (dataset->num_data_ <= 0) {
Guolin Ke's avatar
Guolin Ke committed
716
    Log::Fatal("Data file %s is empty", dataset->data_filename_.c_str());
Guolin Ke's avatar
Guolin Ke committed
717
  }
718
719
  if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
    Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
720
               static_cast<int>(dataset->feature_names_.size()));
721
  }
Guolin Ke's avatar
Guolin Ke committed
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
  bool is_feature_order_by_group = true;
  int last_group = -1;
  int last_sub_feature = -1;
  // if features are ordered, not need to use hist_buf
  for (int i = 0; i < dataset->num_features_; ++i) {
    int group = dataset->feature2group_[i];
    int sub_feature = dataset->feature2subfeature_[i];
    if (group < last_group) {
      is_feature_order_by_group = false;
    } else if (group == last_group) {
      if (sub_feature <= last_sub_feature) {
        is_feature_order_by_group = false;
        break;
      }
    }
    last_group = group;
    last_sub_feature = sub_feature;
  }
  if (!is_feature_order_by_group) {
741
    Log::Fatal("Features in dataset should be ordered by group");
Guolin Ke's avatar
Guolin Ke committed
742
  }
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774

  if (is_load_from_binary) {
    if (dataset->max_bin_ != config_.max_bin) {
      Log::Fatal("Dataset max_bin %d != config %d", dataset->max_bin_, config_.max_bin);
    }
    if (dataset->min_data_in_bin_ != config_.min_data_in_bin) {
      Log::Fatal("Dataset min_data_in_bin %d != config %d", dataset->min_data_in_bin_, config_.min_data_in_bin);
    }
    if (dataset->use_missing_ != config_.use_missing) {
      Log::Fatal("Dataset use_missing %d != config %d", dataset->use_missing_, config_.use_missing);
    }
    if (dataset->zero_as_missing_ != config_.zero_as_missing) {
      Log::Fatal("Dataset zero_as_missing %d != config %d", dataset->zero_as_missing_, config_.zero_as_missing);
    }
    if (dataset->bin_construct_sample_cnt_ != config_.bin_construct_sample_cnt) {
      Log::Fatal("Dataset bin_construct_sample_cnt %d != config %d", dataset->bin_construct_sample_cnt_, config_.bin_construct_sample_cnt);
    }
    if ((dataset->max_bin_by_feature_.size() != config_.max_bin_by_feature.size()) ||
        !std::equal(dataset->max_bin_by_feature_.begin(), dataset->max_bin_by_feature_.end(),
            config_.max_bin_by_feature.begin())) {
      Log::Fatal("Dataset max_bin_by_feature does not match with config");
    }

    int label_idx = -1;
    if (Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx)) {
      if (dataset->label_idx_ != label_idx) {
        Log::Fatal("Dataset label_idx %d != config %d", dataset->zero_as_missing_, config_.zero_as_missing);
      }
    } else {
      Log::Info("Recommend use integer for label index when loading data from binary for sanity check.");
    }
  }
Guolin Ke's avatar
Guolin Ke committed
775
776
777
}

std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
778
779
                                                             int rank, int num_machines, int* num_global_data,
                                                             std::vector<data_size_t>* used_data_indices) {
780
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
781
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
782
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
783
784
785
786
787
788
789
790
791
    // read all lines
    *num_global_data = text_reader.ReadAllLines();
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();

    if (query_boundaries == nullptr) {
      // if not contain query data, minimal sample unit is one record
      *num_global_data = text_reader.ReadAndFilterLines([this, rank, num_machines](data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
792
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
793
794
795
796
797
798
799
800
801
802
803
804
805
806
          return true;
        } else {
          return false;
        }
      }, used_data_indices);
    } else {
      // if contain query data, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.ReadAndFilterLines(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
807
808
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
809
810
811
812
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
813
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
814
815
816
817
818
819
820
821
822
823
824
825
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices);
    }
  }
  return std::move(text_reader.Lines());
}

std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
Guolin Ke's avatar
Guolin Ke committed
826
  int sample_cnt = config_.bin_construct_sample_cnt;
827
828
  if (static_cast<size_t>(sample_cnt) > data.size()) {
    sample_cnt = static_cast<int>(data.size());
829
  }
830
  auto sample_indices = random_.Sample(static_cast<int>(data.size()), sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
831
  std::vector<std::string> out(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
832
833
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    const size_t idx = sample_indices[i];
Guolin Ke's avatar
Guolin Ke committed
834
    out[i] = data[idx];
Guolin Ke's avatar
Guolin Ke committed
835
836
837
838
  }
  return out;
}

839
840
841
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata,
                                                               int rank, int num_machines, int* num_global_data,
                                                               std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
842
  const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
843
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
844
  std::vector<std::string> out_data;
Guolin Ke's avatar
Guolin Ke committed
845
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
846
    *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(&random_, sample_cnt, &out_data));
Guolin Ke's avatar
Guolin Ke committed
847
848
849
850
851
852
853
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      *num_global_data = text_reader.SampleAndFilterFromFile([this, rank, num_machines]
      (data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
854
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
855
856
857
858
          return true;
        } else {
          return false;
        }
Guolin Ke's avatar
Guolin Ke committed
859
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
860
861
862
863
864
865
866
867
868
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.SampleAndFilterFromFile(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
869
870
          Log::Fatal("Query id exceeds the range of the query file, "
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
871
872
873
874
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
875
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
876
877
878
879
880
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
Guolin Ke's avatar
Guolin Ke committed
881
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
882
883
884
885
886
    }
  }
  return out_data;
}

887
888
889
void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
                                                    const std::vector<std::string>& sample_data,
                                                    const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
890
  std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
891
  std::vector<std::vector<int>> sample_indices;
Guolin Ke's avatar
Guolin Ke committed
892
893
  std::vector<std::pair<int, double>> oneline_features;
  double label;
Guolin Ke's avatar
Guolin Ke committed
894
  for (int i = 0; i < static_cast<int>(sample_data.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
895
896
897
898
    oneline_features.clear();
    // parse features
    parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
    for (std::pair<int, double>& inner_data : oneline_features) {
899
      if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
900
901
        sample_values.resize(inner_data.first + 1);
        sample_indices.resize(inner_data.first + 1);
902
      }
Guolin Ke's avatar
Guolin Ke committed
903
      if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
904
905
        sample_values[inner_data.first].emplace_back(inner_data.second);
        sample_indices[inner_data.first].emplace_back(i);
Guolin Ke's avatar
Guolin Ke committed
906
907
908
909
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
910
  dataset->feature_groups_.clear();
911
912
913
914
915
  dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
  if (num_machines > 1) {
    dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_);
  }
  if (!feature_names_.empty()) {
916
    CHECK_EQ(dataset->num_total_features_, static_cast<int>(feature_names_.size()));
917
  }
Guolin Ke's avatar
Guolin Ke committed
918

Belinda Trotta's avatar
Belinda Trotta committed
919
  if (!config_.max_bin_by_feature.empty()) {
920
921
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
922
923
  }

924
925
  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
926
927
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path,
                                                                                    dataset->num_total_features_,
928
929
                                                                                    categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
930
931
932
933
934
935
  // check the range of label_idx, weight_idx and group_idx
  CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
  CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
  CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);

  // fill feature_names_ if not header
Guolin Ke's avatar
Guolin Ke committed
936
  if (feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
937
938
939
940
941
942
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
943
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
944
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
945
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
946
    static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
Guolin Ke's avatar
Guolin Ke committed
947
948
949
  // start find bins
  if (num_machines == 1) {
    // if only one machine, find bin locally
950
    OMP_INIT_EX();
951
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
952
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
953
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
954
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
955
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
956
957
        continue;
      }
958
959
960
961
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
Guolin Ke's avatar
Guolin Ke committed
962
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
963
964
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
965
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
966
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
967
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
968
969
      } else {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
970
                                sample_data.size(), config_.max_bin_by_feature[i],
971
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
972
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
973
      }
974
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
975
    }
976
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
977
978
  } else {
    // start and len will store the process feature indices for different machines
979
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
Guolin Ke's avatar
Guolin Ke committed
980
981
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
982
    int step = (dataset->num_total_features_ + num_machines - 1) / num_machines;
Guolin Ke's avatar
Guolin Ke committed
983
984
985
986
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
987
      len[i] = std::min(step, dataset->num_total_features_ - start[i]);
Guolin Ke's avatar
Guolin Ke committed
988
989
      start[i + 1] = start[i] + len[i];
    }
990
    len[num_machines - 1] = dataset->num_total_features_ - start[num_machines - 1];
991
    OMP_INIT_EX();
992
    #pragma omp parallel for schedule(guided)
993
    for (int i = 0; i < len[rank]; ++i) {
994
      OMP_LOOP_EX_BEGIN();
995
996
997
998
999
1000
1001
1002
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
Nikita Titov's avatar
Nikita Titov committed
1003
      if (static_cast<int>(sample_values.size()) <= start[rank] + i) {
1004
1005
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
1006
      if (config_.max_bin_by_feature.empty()) {
1007
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
1008
                                static_cast<int>(sample_values[start[rank] + i].size()),
1009
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
1010
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
1011
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1012
      } else {
1013
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
1014
                                static_cast<int>(sample_values[start[rank] + i].size()),
1015
                                sample_data.size(), config_.max_bin_by_feature[i],
1016
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type,
1017
                                config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1018
      }
1019
      OMP_LOOP_EX_END();
1020
    }
1021
    OMP_THROW_EX();
1022
    comm_size_t self_buf_size = 0;
Guolin Ke's avatar
Guolin Ke committed
1023
    for (int i = 0; i < len[rank]; ++i) {
1024
1025
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
Guolin Ke's avatar
Guolin Ke committed
1026
      }
1027
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
1028
    }
1029
1030
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1031
    for (int i = 0; i < len[rank]; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1032
1033
1034
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
1035
1036
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
1037
1038
      // free
      bin_mappers[i].reset(nullptr);
Guolin Ke's avatar
Guolin Ke committed
1039
    }
1040
1041
1042
1043
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
Guolin Ke's avatar
Guolin Ke committed
1044
    }
1045
1046
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
Guolin Ke's avatar
Guolin Ke committed
1047
    // gather global feature bin mappers
1048
1049
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1050
    // restore features bins from buffer
1051
    for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1052
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1053
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1054
1055
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
1056
      bin_mappers[i].reset(new BinMapper());
1057
1058
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
Guolin Ke's avatar
Guolin Ke committed
1059
1060
    }
  }
1061
  dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Guolin Ke's avatar
Guolin Ke committed
1062
                     Common::Vector2Ptr<double>(&sample_values).data(),
1063
                     Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
Guolin Ke's avatar
Guolin Ke committed
1064
1065
1066
}

/*! \brief Extract local features from memory */
Guolin Ke's avatar
Guolin Ke committed
1067
void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
1068
1069
  std::vector<std::pair<int, double>> oneline_features;
  double tmp_label = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
1070
  auto& ref_text_data = *text_data;
Guolin Ke's avatar
Guolin Ke committed
1071
  if (predict_fun_ == nullptr) {
1072
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1073
    // if doesn't need to prediction with initial model
1074
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1075
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1076
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1077
1078
1079
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1080
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1081
      // set label
1082
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1083
      // free processed line:
Guolin Ke's avatar
Guolin Ke committed
1084
      ref_text_data[i].clear();
Guolin Ke's avatar
Guolin Ke committed
1085
1086
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
1087
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1088
1089
      // push data
      for (auto& inner_data : oneline_features) {
1090
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1091
1092
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1093
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1094
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1095
1096
1097
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1098
1099
        } else {
          if (inner_data.first == weight_idx_) {
1100
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1101
1102
1103
1104
1105
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1106
      dataset->FinishOneRow(tid, i, is_feature_added);
1107
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1108
    }
1109
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1110
  } else {
1111
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1112
    // if need to prediction with initial model
1113
    std::vector<double> init_score(dataset->num_data_ * num_class_);
1114
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1115
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1116
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1117
1118
1119
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1120
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1121
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1122
1123
      std::vector<double> oneline_init_score(num_class_);
      predict_fun_(oneline_features, oneline_init_score.data());
1124
      for (int k = 0; k < num_class_; ++k) {
1125
        init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1126
1127
      }
      // set label
1128
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1129
1130
1131
1132
1133
      // free processed line:
      text_data[i].clear();
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
Guolin Ke's avatar
Guolin Ke committed
1134
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1135
      for (auto& inner_data : oneline_features) {
1136
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1137
1138
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1139
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1140
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1141
1142
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
1143
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1144
1145
        } else {
          if (inner_data.first == weight_idx_) {
1146
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1147
1148
1149
1150
1151
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1152
      dataset->FinishOneRow(tid, i, is_feature_added);
1153
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1154
    }
1155
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1156
    // metadata_ will manage space of init_score
1157
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1158
  }
Guolin Ke's avatar
Guolin Ke committed
1159
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1160
  // text data can be free after loaded feature values
Guolin Ke's avatar
Guolin Ke committed
1161
  text_data->clear();
Guolin Ke's avatar
Guolin Ke committed
1162
1163
1164
}

/*! \brief Extract local features from file */
1165
1166
void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser,
                                            const std::vector<data_size_t>& used_data_indices, Dataset* dataset) {
1167
  std::vector<double> init_score;
Guolin Ke's avatar
Guolin Ke committed
1168
  if (predict_fun_ != nullptr) {
1169
    init_score = std::vector<double>(dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1170
1171
1172
1173
1174
1175
  }
  std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
    [this, &init_score, &parser, &dataset]
  (data_size_t start_idx, const std::vector<std::string>& lines) {
    std::vector<std::pair<int, double>> oneline_features;
    double tmp_label = 0.0f;
1176
    OMP_INIT_EX();
1177
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1178
    for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
1179
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1180
1181
1182
1183
1184
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1185
      if (!init_score.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1186
1187
        std::vector<double> oneline_init_score(num_class_);
        predict_fun_(oneline_features, oneline_init_score.data());
1188
        for (int k = 0; k < num_class_; ++k) {
1189
          init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1190
1191
1192
        }
      }
      // set label
1193
      dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1194
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1195
1196
      // push data
      for (auto& inner_data : oneline_features) {
1197
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1198
1199
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1200
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1201
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1202
1203
1204
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1205
1206
        } else {
          if (inner_data.first == weight_idx_) {
1207
            dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1208
1209
1210
1211
1212
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1213
      dataset->FinishOneRow(tid, i, is_feature_added);
1214
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1215
    }
1216
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1217
  };
1218
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
1219
  if (!used_data_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1220
1221
1222
1223
1224
1225
1226
1227
    // only need part of data
    text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
  } else {
    // need full data
    text_reader.ReadAllAndProcessParallel(process_fun);
  }

  // metadata_ will manage space of init_score
Guolin Ke's avatar
Guolin Ke committed
1228
  if (!init_score.empty()) {
1229
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1230
  }
Guolin Ke's avatar
Guolin Ke committed
1231
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1232
1233
1234
}

/*! \brief Check can load from binary file */
1235
std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
1236
1237
1238
  std::string bin_filename(filename);
  bin_filename.append(".bin");

1239
  auto reader = VirtualFileReader::Make(bin_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
1240

1241
  if (!reader->Init()) {
1242
    bin_filename = std::string(filename);
1243
1244
    reader = VirtualFileReader::Make(bin_filename.c_str());
    if (!reader->Init()) {
1245
      Log::Fatal("Cannot open data file %s", bin_filename.c_str());
1246
    }
1247
  }
1248
1249
1250
1251
1252

  size_t buffer_size = 256;
  auto buffer = std::vector<char>(buffer_size);
  // read size of token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
1253
  size_t read_cnt = reader->Read(buffer.data(), size_of_token);
1254
1255
  if (read_cnt == size_of_token
      && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
1256
    return bin_filename;
Guolin Ke's avatar
Guolin Ke committed
1257
  } else {
1258
    return std::string();
Guolin Ke's avatar
Guolin Ke committed
1259
1260
1261
  }
}

1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274


std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
                                                              const std::unordered_set<int>& categorical_features) {
  std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
  if (forced_bins_path != "") {
    std::ifstream forced_bins_stream(forced_bins_path.c_str());
    if (forced_bins_stream.fail()) {
      Log::Warning("Could not open %s. Will ignore.", forced_bins_path.c_str());
    } else {
      std::stringstream buffer;
      buffer << forced_bins_stream.rdbuf();
      std::string err;
Guolin Ke's avatar
Guolin Ke committed
1275
      Json forced_bins_json = Json::parse(buffer.str(), &err);
1276
1277
1278
1279
      CHECK(forced_bins_json.is_array());
      std::vector<Json> forced_bins_arr = forced_bins_json.array_items();
      for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
        int feature_num = forced_bins_arr[i]["feature"].int_value();
Nikita Titov's avatar
Nikita Titov committed
1280
        CHECK_LT(feature_num, num_total_features);
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
        if (categorical_features.count(feature_num)) {
          Log::Warning("Feature %d is categorical. Will ignore forced bins for this  feature.", feature_num);
        } else {
          std::vector<Json> bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items();
          for (size_t j = 0; j < bounds_arr.size(); ++j) {
            forced_bins[feature_num].push_back(bounds_arr[j].number_value());
          }
        }
      }
      // remove duplicates
      for (int i = 0; i < num_total_features; ++i) {
        auto new_end = std::unique(forced_bins[i].begin(), forced_bins[i].end());
        forced_bins[i].erase(new_end, forced_bins[i].end());
      }
    }
  }
  return forced_bins;
}

1300
}  // namespace LightGBM