dataset_loader.cpp 53.4 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/dataset_loader.h>

Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/network.h>
8
#include <LightGBM/utils/array_args.h>
9
#include <LightGBM/utils/json11.h>
10
11
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
12

13
14
#include <fstream>

Guolin Ke's avatar
Guolin Ke committed
15
16
namespace LightGBM {

17
18
using json11::Json;

Guolin Ke's avatar
Guolin Ke committed
19
20
DatasetLoader::DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename)
  : config_(io_config),
    // Seed from the constructor argument so this does not depend on
    // config_ being declared (and thus initialized) before random_.
    random_(io_config.data_random_seed),
    predict_fun_(predict_fun),
    num_class_(num_class) {
  // Column defaults: label is column 0, weight/group are unset.
  // SetHeader may override any of these from the configuration.
  label_idx_ = 0;
  group_idx_ = NO_SPECIFIC;
  weight_idx_ = NO_SPECIFIC;
  SetHeader(filename);
}

// Nothing to release manually; all members clean up after themselves.
DatasetLoader::~DatasetLoader() = default;

Guolin Ke's avatar
Guolin Ke committed
30
void DatasetLoader::SetHeader(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
31
  std::unordered_map<std::string, int> name2idx;
Guolin Ke's avatar
Guolin Ke committed
32
33
  std::string name_prefix("name:");
  if (filename != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
34
    TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
35

Guolin Ke's avatar
Guolin Ke committed
36
    // get column names
Guolin Ke's avatar
Guolin Ke committed
37
    if (config_.header) {
Guolin Ke's avatar
Guolin Ke committed
38
      std::string first_line = text_reader.first_line();
39
      feature_names_ = Common::Split(first_line.c_str(), "\t,");
Guolin Ke's avatar
Guolin Ke committed
40
41
    }

Guolin Ke's avatar
Guolin Ke committed
42
    // load label idx first
Guolin Ke's avatar
Guolin Ke committed
43
44
45
    if (config_.label_column.size() > 0) {
      if (Common::StartsWith(config_.label_column, name_prefix)) {
        std::string name = config_.label_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
46
47
48
49
50
51
52
53
54
55
        label_idx_ = -1;
        for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
          if (name == feature_names_[i]) {
            label_idx_ = i;
            break;
          }
        }
        if (label_idx_ >= 0) {
          Log::Info("Using column %s as label", name.c_str());
        } else {
56
57
          Log::Fatal("Could not find label column %s in data file \n"
                     "or data file doesn't contain header", name.c_str());
Guolin Ke's avatar
Guolin Ke committed
58
        }
Guolin Ke's avatar
Guolin Ke committed
59
      } else {
Guolin Ke's avatar
Guolin Ke committed
60
        if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
61
62
63
          Log::Fatal("label_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
64
65
        }
        Log::Info("Using column number %d as label", label_idx_);
Guolin Ke's avatar
Guolin Ke committed
66
67
      }
    }
Guolin Ke's avatar
Guolin Ke committed
68

Guolin Ke's avatar
Guolin Ke committed
69
    if (!feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
70
71
72
73
      // erase label column name
      feature_names_.erase(feature_names_.begin() + label_idx_);
      for (size_t i = 0; i < feature_names_.size(); ++i) {
        name2idx[feature_names_[i]] = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
74
      }
Guolin Ke's avatar
Guolin Ke committed
75
76
77
    }

    // load ignore columns
Guolin Ke's avatar
Guolin Ke committed
78
79
80
    if (config_.ignore_column.size() > 0) {
      if (Common::StartsWith(config_.ignore_column, name_prefix)) {
        std::string names = config_.ignore_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
81
82
83
84
85
86
87
88
89
        for (auto name : Common::Split(names.c_str(), ',')) {
          if (name2idx.count(name) > 0) {
            int tmp = name2idx[name];
            ignore_features_.emplace(tmp);
          } else {
            Log::Fatal("Could not find ignore column %s in data file", name.c_str());
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
90
        for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
Guolin Ke's avatar
Guolin Ke committed
91
92
          int tmp = 0;
          if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
93
94
95
            Log::Fatal("ignore_column is not a number,\n"
                       "if you want to use a column name,\n"
                       "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
96
97
          }
          ignore_features_.emplace(tmp);
Guolin Ke's avatar
Guolin Ke committed
98
99
100
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
101
    // load weight idx
Guolin Ke's avatar
Guolin Ke committed
102
103
104
    if (config_.weight_column.size() > 0) {
      if (Common::StartsWith(config_.weight_column, name_prefix)) {
        std::string name = config_.weight_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
105
106
107
108
109
110
        if (name2idx.count(name) > 0) {
          weight_idx_ = name2idx[name];
          Log::Info("Using column %s as weight", name.c_str());
        } else {
          Log::Fatal("Could not find weight column %s in data file", name.c_str());
        }
Guolin Ke's avatar
Guolin Ke committed
111
      } else {
Guolin Ke's avatar
Guolin Ke committed
112
        if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
113
114
115
          Log::Fatal("weight_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
116
117
        }
        Log::Info("Using column number %d as weight", weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
118
      }
Guolin Ke's avatar
Guolin Ke committed
119
      ignore_features_.emplace(weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
120
    }
Guolin Ke's avatar
Guolin Ke committed
121
    // load group idx
Guolin Ke's avatar
Guolin Ke committed
122
123
124
    if (config_.group_column.size() > 0) {
      if (Common::StartsWith(config_.group_column, name_prefix)) {
        std::string name = config_.group_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
125
126
127
128
129
130
131
        if (name2idx.count(name) > 0) {
          group_idx_ = name2idx[name];
          Log::Info("Using column %s as group/query id", name.c_str());
        } else {
          Log::Fatal("Could not find group/query column %s in data file", name.c_str());
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
132
        if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
133
134
135
          Log::Fatal("group_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
136
137
138
139
        }
        Log::Info("Using column number %d as group/query id", group_idx_);
      }
      ignore_features_.emplace(group_idx_);
Guolin Ke's avatar
Guolin Ke committed
140
141
    }
  }
Guolin Ke's avatar
Guolin Ke committed
142
143
144
  if (config_.categorical_feature.size() > 0) {
    if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
      std::string names = config_.categorical_feature.substr(name_prefix.size());
145
146
147
148
149
      for (auto name : Common::Split(names.c_str(), ',')) {
        if (name2idx.count(name) > 0) {
          int tmp = name2idx[name];
          categorical_features_.emplace(tmp);
        } else {
Guolin Ke's avatar
Guolin Ke committed
150
          Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
151
152
153
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
154
      for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
155
156
        int tmp = 0;
        if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Guolin Ke's avatar
Guolin Ke committed
157
          Log::Fatal("categorical_feature is not a number,\n"
158
159
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
160
161
162
163
164
        }
        categorical_features_.emplace(tmp);
      }
    }
  }
Guolin Ke's avatar
Guolin Ke committed
165
166
}

167
168
169
170
171
172
173
174
175
176
// Warn when the bin-construction sample is both small in absolute terms
// (fewer than 100000 rows) and small relative to the data (under 20%).
void CheckSampleSize(size_t sample_cnt, size_t num_data) {
  if (sample_cnt >= 100000) {
    return;  // large enough regardless of the ratio
  }
  const double sample_ratio = static_cast<double>(sample_cnt) / num_data;
  if (sample_ratio < 0.2f) {
    Log::Warning(
        "Using too small ``bin_construct_sample_cnt`` may encounter "
        "unexpected "
        "errors and poor accuracy.");
  }
}

177
/*!
 * \brief Load a training dataset from a text or binary (cached) file.
 *
 * A binary cache found by CheckCanLoadFromBin is preferred. Otherwise the text
 * file is parsed either fully in memory or in two passes (config_.two_round).
 * \param filename Path of the data file.
 * \param rank Rank of this machine in distributed training.
 * \param num_machines Number of machines in distributed training.
 * \return Newly allocated Dataset; ownership passes to the caller.
 */
Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_machines) {
  // don't support query id in data file when training in parallel.
  // Column 0 is a valid group column, so test against the "unset" sentinel
  // rather than `> 0`.
  if (num_machines > 1 && !config_.pre_partition) {
    if (group_idx_ != NO_SPECIFIC) {
      Log::Fatal("Using a query id without pre-partitioning the data file is not supported for parallel training.\n"
                 "Please use an additional query file or pre-partition the data");
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto bin_filename = CheckCanLoadFromBin(filename);
  if (bin_filename.size() == 0) {
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data to memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // sample data
      auto sample_data = SampleTextDataFromMemory(text_data);
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      // sample data from file
      auto sample_data = SampleTextDataFromFile(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      if (used_data_indices.size() > 0) {
        dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
      } else {
        dataset->num_data_ = num_global_data;
      }
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      Log::Debug("Making second pass...");
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
  }
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  // need to check training data
  CheckDataset(dataset.get());
  return dataset.release();
}



243
/*!
 * \brief Load a validation dataset whose bins are aligned with train_data.
 *
 * Always loads as a single machine (rank 0 of 1). No CheckDataset pass is run
 * on validation data.
 * \return Newly allocated Dataset; ownership passes to the caller.
 */
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  std::unique_ptr<Dataset> valid_set(new Dataset());
  const auto bin_filename = CheckCanLoadFromBin(filename);

  if (!bin_filename.empty()) {
    // A binary cache exists: load it as machine 0 of 1.
    valid_set.reset(LoadFromBinFile(filename, bin_filename.c_str(), 0, 1, &num_global_data, &used_data_indices));
  } else {
    std::unique_ptr<Parser> parser(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (!parser) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    valid_set->data_filename_ = filename;
    valid_set->label_idx_ = label_idx_;
    valid_set->metadata_.Init(filename);

    if (config_.two_round) {
      // Two passes: count rows first, then stream features from disk.
      TextReader<data_size_t> text_reader(filename, config_.header);
      valid_set->num_data_ = static_cast<data_size_t>(text_reader.CountLine());
      num_global_data = valid_set->num_data_;
      valid_set->metadata_.Init(valid_set->num_data_, weight_idx_, group_idx_);
      valid_set->CreateValid(train_data);
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, valid_set.get());
    } else {
      // Single pass: buffer all lines in memory, then extract.
      auto lines = LoadTextDataToMemory(filename, valid_set->metadata_, 0, 1, &num_global_data, &used_data_indices);
      valid_set->num_data_ = static_cast<data_size_t>(lines.size());
      valid_set->metadata_.Init(valid_set->num_data_, weight_idx_, group_idx_);
      valid_set->CreateValid(train_data);
      ExtractFeaturesFromMemory(&lines, parser.get(), valid_set.get());
      lines.clear();
    }
  }

  // validation data only gets its metadata checked
  valid_set->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  return valid_set.release();
}

287
288
289
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
                                        int rank, int num_machines, int* num_global_data,
                                        std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
290
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
291
  auto reader = VirtualFileReader::Make(bin_filename);
Guolin Ke's avatar
Guolin Ke committed
292
  dataset->data_filename_ = data_filename;
293
  if (!reader->Init()) {
Guolin Ke's avatar
Guolin Ke committed
294
295
296
297
298
    Log::Fatal("Could not read binary data from %s", bin_filename);
  }

  // buffer to read binary file
  size_t buffer_size = 16 * 1024 * 1024;
Guolin Ke's avatar
Guolin Ke committed
299
  auto buffer = std::vector<char>(buffer_size);
300

301
302
  // check token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
303
304
305
306
  size_t read_cnt = reader->Read(
      buffer.data(),
      VirtualFileWriter::AlignedSize(sizeof(char) * size_of_token));
  if (read_cnt < sizeof(char) * size_of_token) {
307
308
309
    Log::Fatal("Binary file error: token has the wrong size");
  }
  if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
310
    Log::Fatal("Input file is not LightGBM binary file");
311
  }
Guolin Ke's avatar
Guolin Ke committed
312
313

  // read size of header
314
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
315

316
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
317
318
319
    Log::Fatal("Binary file error: header has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
320
  size_t size_of_head = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
321
322
323
324

  // re-allocmate space if not enough
  if (size_of_head > buffer_size) {
    buffer_size = size_of_head;
Guolin Ke's avatar
Guolin Ke committed
325
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
326
327
  }
  // read header
328
  read_cnt = reader->Read(buffer.data(), size_of_head);
Guolin Ke's avatar
Guolin Ke committed
329
330
331
332
333

  if (read_cnt != size_of_head) {
    Log::Fatal("Binary file error: header is incorrect");
  }
  // get header
Guolin Ke's avatar
Guolin Ke committed
334
  const char* mem_ptr = buffer.data();
Guolin Ke's avatar
Guolin Ke committed
335
  dataset->num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
336
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_data_));
Guolin Ke's avatar
Guolin Ke committed
337
  dataset->num_features_ = *(reinterpret_cast<const int*>(mem_ptr));
338
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_features_));
Guolin Ke's avatar
Guolin Ke committed
339
  dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
340
341
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(dataset->num_total_features_));
Guolin Ke's avatar
Guolin Ke committed
342
  dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
343
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->label_idx_));
344
  dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
345
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->max_bin_));
346
  dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
347
348
  mem_ptr += VirtualFileWriter::AlignedSize(
      sizeof(dataset->bin_construct_sample_cnt_));
349
  dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
350
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->min_data_in_bin_));
351
  dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
352
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->use_missing_));
353
  dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
354
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->zero_as_missing_));
Guolin Ke's avatar
Guolin Ke committed
355
356
  const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
  dataset->used_feature_map_.clear();
Guolin Ke's avatar
Guolin Ke committed
357
  for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
358
359
    dataset->used_feature_map_.push_back(tmp_feature_map[i]);
  }
360
361
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int) *
                                            dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
362
363
  // num_groups
  dataset->num_groups_ = *(reinterpret_cast<const int*>(mem_ptr));
364
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
365
366
367
368
369
370
  // real_feature_idx_
  const int* tmp_ptr_real_feature_idx_ = reinterpret_cast<const int*>(mem_ptr);
  dataset->real_feature_idx_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->real_feature_idx_.push_back(tmp_ptr_real_feature_idx_[i]);
  }
371
372
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
373
374
375
376
377
378
  // feature2group
  const int* tmp_ptr_feature2group = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2group_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2group_.push_back(tmp_ptr_feature2group[i]);
  }
379
380
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
381
382
383
384
385
386
  // feature2subfeature
  const int* tmp_ptr_feature2subfeature = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2subfeature_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2subfeature_.push_back(tmp_ptr_feature2subfeature[i]);
  }
387
388
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
389
390
391
392
393
394
395
396
397
398
399
  // group_bin_boundaries
  const uint64_t* tmp_ptr_group_bin_boundaries = reinterpret_cast<const uint64_t*>(mem_ptr);
  dataset->group_bin_boundaries_.clear();
  for (int i = 0; i < dataset->num_groups_ + 1; ++i) {
    dataset->group_bin_boundaries_.push_back(tmp_ptr_group_bin_boundaries[i]);
  }
  mem_ptr += sizeof(uint64_t) * (dataset->num_groups_ + 1);

  // group_feature_start_
  const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_start_.clear();
400
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
401
402
    dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
  }
403
404
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
405
406
407
408
409
410
411

  // group_feature_cnt_
  const int* tmp_ptr_group_feature_cnt = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_cnt_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_cnt_.push_back(tmp_ptr_group_feature_cnt[i]);
  }
412
413
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
414

Belinda Trotta's avatar
Belinda Trotta committed
415
  if (!config_.max_bin_by_feature.empty()) {
416
417
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
418
419
420
421
422
423
424
425
426
    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
  } else {
    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
    dataset->max_bin_by_feature_.clear();
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
    }
  }
427
428
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int32_t) *
                                            (dataset->num_total_features_));
Belinda Trotta's avatar
Belinda Trotta committed
429
430
431
432
  if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
    dataset->max_bin_by_feature_.clear();
  }

Guolin Ke's avatar
Guolin Ke committed
433
  // get feature names
Guolin Ke's avatar
Guolin Ke committed
434
  dataset->feature_names_.clear();
Guolin Ke's avatar
Guolin Ke committed
435
436
437
  // write feature names
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int str_len = *(reinterpret_cast<const int*>(mem_ptr));
438
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
Guolin Ke's avatar
Guolin Ke committed
439
    std::stringstream str_buf;
440
    auto tmp_arr = reinterpret_cast<const char*>(mem_ptr);
Guolin Ke's avatar
Guolin Ke committed
441
    for (int j = 0; j < str_len; ++j) {
442
      char tmp_char = tmp_arr[j];
Guolin Ke's avatar
Guolin Ke committed
443
444
      str_buf << tmp_char;
    }
445
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(char) * str_len);
Guolin Ke's avatar
Guolin Ke committed
446
    dataset->feature_names_.emplace_back(str_buf.str());
Guolin Ke's avatar
Guolin Ke committed
447
  }
448
449
450
451
  // get forced_bin_bounds_
  dataset->forced_bin_bounds_ = std::vector<std::vector<double>>(dataset->num_total_features_, std::vector<double>());
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int num_bounds = *(reinterpret_cast<const int*>(mem_ptr));
452
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
453
    dataset->forced_bin_bounds_[i] = std::vector<double>();
454
455
    const double* tmp_ptr_forced_bounds =
        reinterpret_cast<const double*>(mem_ptr);
456
457
458
459
460
461
    for (int j = 0; j < num_bounds; ++j) {
      double bound = tmp_ptr_forced_bounds[j];
      dataset->forced_bin_bounds_[i].push_back(bound);
    }
    mem_ptr += num_bounds * sizeof(double);
  }
Guolin Ke's avatar
Guolin Ke committed
462
463

  // read size of meta data
464
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
465

466
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
467
468
469
    Log::Fatal("Binary file error: meta data has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
470
  size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
471
472
473
474

  // re-allocate space if not enough
  if (size_of_metadata > buffer_size) {
    buffer_size = size_of_metadata;
Guolin Ke's avatar
Guolin Ke committed
475
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
476
477
  }
  //  read meta data
478
  read_cnt = reader->Read(buffer.data(), size_of_metadata);
Guolin Ke's avatar
Guolin Ke committed
479
480
481
482
483

  if (read_cnt != size_of_metadata) {
    Log::Fatal("Binary file error: meta data is incorrect");
  }
  // load meta data
Guolin Ke's avatar
Guolin Ke committed
484
  dataset->metadata_.LoadFromMemory(buffer.data());
Guolin Ke's avatar
Guolin Ke committed
485

486
487
  *num_global_data = dataset->num_data_;
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
488
  // sample local used data if need to partition
Guolin Ke's avatar
Guolin Ke committed
489
  if (num_machines > 1 && !config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
490
491
492
493
    const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
494
        if (random_.NextShort(0, num_machines) == rank) {
495
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
496
497
498
499
500
501
502
503
504
        }
      }
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = dataset->metadata_.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (qid >= num_queries) {
505
506
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
507
508
509
510
        }
        if (i >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
511
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
512
513
514
515
516
            is_query_used = true;
          }
          ++qid;
        }
        if (is_query_used) {
517
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
518
519
520
        }
      }
    }
521
    dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
Guolin Ke's avatar
Guolin Ke committed
522
  }
523
  dataset->metadata_.PartitionLabel(*used_data_indices);
Guolin Ke's avatar
Guolin Ke committed
524
  // read feature data
Guolin Ke's avatar
Guolin Ke committed
525
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
526
    // read feature size
527
528
    read_cnt = reader->Read(buffer.data(), sizeof(size_t));
    if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
529
530
      Log::Fatal("Binary file error: feature %d has the wrong size", i);
    }
Guolin Ke's avatar
Guolin Ke committed
531
    size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
532
533
534
    // re-allocate space if not enough
    if (size_of_feature > buffer_size) {
      buffer_size = size_of_feature;
Guolin Ke's avatar
Guolin Ke committed
535
      buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
536
537
    }

538
    read_cnt = reader->Read(buffer.data(), size_of_feature);
Guolin Ke's avatar
Guolin Ke committed
539
540
541
542

    if (read_cnt != size_of_feature) {
      Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
    }
Guolin Ke's avatar
Guolin Ke committed
543
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
544
545
      new FeatureGroup(buffer.data(),
                       *num_global_data,
546
                       *used_data_indices)));
Guolin Ke's avatar
Guolin Ke committed
547
  }
Guolin Ke's avatar
Guolin Ke committed
548
  dataset->feature_groups_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
549
  dataset->is_finish_load_ = true;
Guolin Ke's avatar
Guolin Ke committed
550
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
551
552
}

553

554
555
556
/*!
 * \brief Build a Dataset (bin mappers + feature groups) from column-wise samples.
 * \param sample_values per-column arrays of sampled non-zero values
 * \param sample_indices per-column arrays of the sample row ids of those values
 * \param num_col number of columns known locally
 * \param num_per_col number of sampled values stored for each column
 * \param total_sample_size number of sampled rows
 * \param num_data number of rows of the dataset to construct
 * \return newly allocated Dataset; ownership passes to the caller
 */
Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
                                                int** sample_indices, int num_col, const int* num_per_col,
                                                size_t total_sample_size, data_size_t num_data) {
  CheckSampleSize(total_sample_size, static_cast<size_t>(num_data));
  int num_total_features = num_col;
  if (Network::num_machines() > 1) {
    // every machine must agree on the global number of columns
    num_total_features = Network::GlobalSyncUpByMax(num_total_features);
  }
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_total_features);
  // fill feature_names_ if not header
  if (feature_names_.empty()) {
    for (int i = 0; i < num_col; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
  if (!config_.max_bin_by_feature.empty()) {
    CHECK_EQ(static_cast<size_t>(num_col), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
  }

  // Monotone constraints are not supported on categorical features. Validate all
  // non-ignored categorical columns up front, on every machine, so that each rank
  // rejects the configuration identically (instead of only the single-machine
  // path checking, from inside the parallel loop).
  for (int i = 0; i < num_col; ++i) {
    if (ignore_features_.count(i) > 0 || categorical_features_.count(i) == 0) {
      continue;
    }
    const bool feat_is_constrained = static_cast<size_t>(i) < config_.monotone_constraints.size()
                                     && config_.monotone_constraints[i] != 0;
    if (feat_is_constrained) {
      Log::Fatal("The output cannot be monotone with respect to categorical features");
    }
  }

  // get forced split; indexed by global column index
  std::string forced_bins_path = config_.forcedbins_filename;
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);

  // min_data_in_leaf rescaled from the full data to the sample size
  const data_size_t filter_cnt = static_cast<data_size_t>(
    static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
  if (Network::num_machines() == 1) {
    // if only one machine, find bin locally
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
      const int max_bin = config_.max_bin_by_feature.empty() ? config_.max_bin : config_.max_bin_by_feature[i];
      bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                              max_bin, config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter,
                              bin_type, config_.use_missing, config_.zero_as_missing,
                              forced_bin_bounds[i]);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    // if have multi-machines, need to find bin distributed
    // different machines will find bin for different features
    const int num_machines = Network::num_machines();
    const int rank = Network::rank();
    // start and len will store the process feature indices for different machines
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
    int step = (num_total_features + num_machines - 1) / num_machines;
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
      len[i] = std::min(step, num_total_features - start[i]);
      start[i + 1] = start[i] + len[i];
    }
    len[num_machines - 1] = num_total_features - start[num_machines - 1];

    // Local mappers are built in the scratch slots bin_mappers[0, len[rank]).
    // `i` is the local slot, `fidx` the global feature index it represents; every
    // per-feature lookup (samples, max_bin, forced bins) must use `fidx`.
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < len[rank]; ++i) {
      OMP_LOOP_EX_BEGIN();
      const int fidx = start[rank] + i;
      if (ignore_features_.count(fidx) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(fidx)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
      // columns not present locally keep a default-constructed mapper
      if (num_col <= fidx) {
        continue;
      }
      const int max_bin = config_.max_bin_by_feature.empty() ? config_.max_bin : config_.max_bin_by_feature[fidx];
      bin_mappers[i]->FindBin(sample_values[fidx], num_per_col[fidx], total_sample_size,
                              max_bin, config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter,
                              bin_type, config_.use_missing, config_.zero_as_missing,
                              forced_bin_bounds[fidx]);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();

    // serialize the local mappers
    comm_size_t self_buf_size = 0;
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
    }
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
      // free the scratch slot; it is refilled from the gathered buffer below
      bin_mappers[i].reset(nullptr);
    }
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
    }
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
    // gather global feature bin mappers
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
    // restore features bins from buffer, now at their global index
    for (int i = 0; i < num_total_features; ++i) {
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      bin_mappers[i].reset(new BinMapper());
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
  dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
  dataset->set_feature_names(feature_names_);
  return dataset.release();
}
Guolin Ke's avatar
Guolin Ke committed
706
707
708
709
710
711


// ---- private functions ----

void DatasetLoader::CheckDataset(const Dataset* dataset) {
  if (dataset->num_data_ <= 0) {
Guolin Ke's avatar
Guolin Ke committed
712
    Log::Fatal("Data file %s is empty", dataset->data_filename_.c_str());
Guolin Ke's avatar
Guolin Ke committed
713
  }
714
715
  if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
    Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
716
               static_cast<int>(dataset->feature_names_.size()));
717
  }
Guolin Ke's avatar
Guolin Ke committed
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
  bool is_feature_order_by_group = true;
  int last_group = -1;
  int last_sub_feature = -1;
  // if features are ordered, not need to use hist_buf
  for (int i = 0; i < dataset->num_features_; ++i) {
    int group = dataset->feature2group_[i];
    int sub_feature = dataset->feature2subfeature_[i];
    if (group < last_group) {
      is_feature_order_by_group = false;
    } else if (group == last_group) {
      if (sub_feature <= last_sub_feature) {
        is_feature_order_by_group = false;
        break;
      }
    }
    last_group = group;
    last_sub_feature = sub_feature;
  }
  if (!is_feature_order_by_group) {
737
    Log::Fatal("Features in dataset should be ordered by group");
Guolin Ke's avatar
Guolin Ke committed
738
  }
Guolin Ke's avatar
Guolin Ke committed
739
740
741
}

std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
742
743
                                                             int rank, int num_machines, int* num_global_data,
                                                             std::vector<data_size_t>* used_data_indices) {
744
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
745
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
746
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
747
748
749
750
751
752
753
754
755
    // read all lines
    *num_global_data = text_reader.ReadAllLines();
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();

    if (query_boundaries == nullptr) {
      // if not contain query data, minimal sample unit is one record
      *num_global_data = text_reader.ReadAndFilterLines([this, rank, num_machines](data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
756
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
757
758
759
760
761
762
763
764
765
766
767
768
769
770
          return true;
        } else {
          return false;
        }
      }, used_data_indices);
    } else {
      // if contain query data, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.ReadAndFilterLines(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
771
772
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
773
774
775
776
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
777
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
778
779
780
781
782
783
784
785
786
787
788
789
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices);
    }
  }
  return std::move(text_reader.Lines());
}

std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
Guolin Ke's avatar
Guolin Ke committed
790
  int sample_cnt = config_.bin_construct_sample_cnt;
791
792
  if (static_cast<size_t>(sample_cnt) > data.size()) {
    sample_cnt = static_cast<int>(data.size());
793
  }
794
  auto sample_indices = random_.Sample(static_cast<int>(data.size()), sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
795
  std::vector<std::string> out(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
796
797
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    const size_t idx = sample_indices[i];
Guolin Ke's avatar
Guolin Ke committed
798
    out[i] = data[idx];
Guolin Ke's avatar
Guolin Ke committed
799
800
801
802
  }
  return out;
}

803
804
805
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata,
                                                               int rank, int num_machines, int* num_global_data,
                                                               std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
806
  const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
807
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
808
  std::vector<std::string> out_data;
Guolin Ke's avatar
Guolin Ke committed
809
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
810
    *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(&random_, sample_cnt, &out_data));
Guolin Ke's avatar
Guolin Ke committed
811
812
813
814
815
816
817
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      *num_global_data = text_reader.SampleAndFilterFromFile([this, rank, num_machines]
      (data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
818
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
819
820
821
822
          return true;
        } else {
          return false;
        }
Guolin Ke's avatar
Guolin Ke committed
823
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
824
825
826
827
828
829
830
831
832
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.SampleAndFilterFromFile(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
833
834
          Log::Fatal("Query id exceeds the range of the query file, "
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
835
836
837
838
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
839
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
840
841
842
843
844
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
Guolin Ke's avatar
Guolin Ke committed
845
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
846
847
848
849
850
    }
  }
  return out_data;
}

851
852
853
void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
                                                    const std::vector<std::string>& sample_data,
                                                    const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
854
  std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
855
  std::vector<std::vector<int>> sample_indices;
Guolin Ke's avatar
Guolin Ke committed
856
857
  std::vector<std::pair<int, double>> oneline_features;
  double label;
Guolin Ke's avatar
Guolin Ke committed
858
  for (int i = 0; i < static_cast<int>(sample_data.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
859
860
861
862
    oneline_features.clear();
    // parse features
    parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
    for (std::pair<int, double>& inner_data : oneline_features) {
863
      if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
864
865
        sample_values.resize(inner_data.first + 1);
        sample_indices.resize(inner_data.first + 1);
866
      }
Guolin Ke's avatar
Guolin Ke committed
867
      if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
868
869
        sample_values[inner_data.first].emplace_back(inner_data.second);
        sample_indices[inner_data.first].emplace_back(i);
Guolin Ke's avatar
Guolin Ke committed
870
871
872
873
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
874
  dataset->feature_groups_.clear();
875
876
877
878
879
  dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
  if (num_machines > 1) {
    dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_);
  }
  if (!feature_names_.empty()) {
880
    CHECK_EQ(dataset->num_total_features_, static_cast<int>(feature_names_.size()));
881
  }
Guolin Ke's avatar
Guolin Ke committed
882

Belinda Trotta's avatar
Belinda Trotta committed
883
  if (!config_.max_bin_by_feature.empty()) {
884
885
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
886
887
  }

888
889
  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
890
891
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path,
                                                                                    dataset->num_total_features_,
892
893
                                                                                    categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
894
895
896
897
898
899
  // check the range of label_idx, weight_idx and group_idx
  CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
  CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
  CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);

  // fill feature_names_ if not header
Guolin Ke's avatar
Guolin Ke committed
900
  if (feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
901
902
903
904
905
906
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
907
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
908
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
909
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
910
    static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
Guolin Ke's avatar
Guolin Ke committed
911
912
913
  // start find bins
  if (num_machines == 1) {
    // if only one machine, find bin locally
914
    OMP_INIT_EX();
915
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
916
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
917
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
918
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
919
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
920
921
        continue;
      }
922
923
924
925
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
Guolin Ke's avatar
Guolin Ke committed
926
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
927
928
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
929
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
930
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
931
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
932
933
      } else {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
934
                                sample_data.size(), config_.max_bin_by_feature[i],
935
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
936
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
937
      }
938
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
939
    }
940
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
941
942
  } else {
    // start and len will store the process feature indices for different machines
943
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
Guolin Ke's avatar
Guolin Ke committed
944
945
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
946
    int step = (dataset->num_total_features_ + num_machines - 1) / num_machines;
Guolin Ke's avatar
Guolin Ke committed
947
948
949
950
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
951
      len[i] = std::min(step, dataset->num_total_features_ - start[i]);
Guolin Ke's avatar
Guolin Ke committed
952
953
      start[i + 1] = start[i] + len[i];
    }
954
    len[num_machines - 1] = dataset->num_total_features_ - start[num_machines - 1];
955
    OMP_INIT_EX();
956
    #pragma omp parallel for schedule(guided)
957
    for (int i = 0; i < len[rank]; ++i) {
958
      OMP_LOOP_EX_BEGIN();
959
960
961
962
963
964
965
966
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
Nikita Titov's avatar
Nikita Titov committed
967
      if (static_cast<int>(sample_values.size()) <= start[rank] + i) {
968
969
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
970
      if (config_.max_bin_by_feature.empty()) {
971
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
972
                                static_cast<int>(sample_values[start[rank] + i].size()),
973
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
974
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
975
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
976
      } else {
977
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
978
                                static_cast<int>(sample_values[start[rank] + i].size()),
979
                                sample_data.size(), config_.max_bin_by_feature[i],
980
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type,
981
                                config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
982
      }
983
      OMP_LOOP_EX_END();
984
    }
985
    OMP_THROW_EX();
986
    comm_size_t self_buf_size = 0;
Guolin Ke's avatar
Guolin Ke committed
987
    for (int i = 0; i < len[rank]; ++i) {
988
989
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
Guolin Ke's avatar
Guolin Ke committed
990
      }
991
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
992
    }
993
994
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
995
    for (int i = 0; i < len[rank]; ++i) {
Guolin Ke's avatar
Guolin Ke committed
996
997
998
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
999
1000
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
1001
1002
      // free
      bin_mappers[i].reset(nullptr);
Guolin Ke's avatar
Guolin Ke committed
1003
    }
1004
1005
1006
1007
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
Guolin Ke's avatar
Guolin Ke committed
1008
    }
1009
1010
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
Guolin Ke's avatar
Guolin Ke committed
1011
    // gather global feature bin mappers
1012
1013
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1014
    // restore features bins from buffer
1015
    for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1016
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1017
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1018
1019
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
1020
      bin_mappers[i].reset(new BinMapper());
1021
1022
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
Guolin Ke's avatar
Guolin Ke committed
1023
1024
    }
  }
1025
  dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Guolin Ke's avatar
Guolin Ke committed
1026
                     Common::Vector2Ptr<double>(&sample_values).data(),
1027
                     Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
Guolin Ke's avatar
Guolin Ke committed
1028
1029
1030
}

/*! \brief Extract local features from memory */
Guolin Ke's avatar
Guolin Ke committed
1031
void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
1032
1033
  std::vector<std::pair<int, double>> oneline_features;
  double tmp_label = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
1034
  auto& ref_text_data = *text_data;
Guolin Ke's avatar
Guolin Ke committed
1035
  if (predict_fun_ == nullptr) {
1036
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1037
    // if doesn't need to prediction with initial model
1038
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1039
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1040
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1041
1042
1043
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1044
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1045
      // set label
1046
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1047
      // free processed line:
Guolin Ke's avatar
Guolin Ke committed
1048
      ref_text_data[i].clear();
Guolin Ke's avatar
Guolin Ke committed
1049
1050
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
1051
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1052
1053
      // push data
      for (auto& inner_data : oneline_features) {
1054
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1055
1056
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1057
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1058
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1059
1060
1061
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1062
1063
        } else {
          if (inner_data.first == weight_idx_) {
1064
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1065
1066
1067
1068
1069
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1070
      dataset->FinishOneRow(tid, i, is_feature_added);
1071
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1072
    }
1073
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1074
  } else {
1075
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1076
    // if need to prediction with initial model
1077
    std::vector<double> init_score(dataset->num_data_ * num_class_);
1078
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1079
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1080
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1081
1082
1083
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1084
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1085
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1086
1087
      std::vector<double> oneline_init_score(num_class_);
      predict_fun_(oneline_features, oneline_init_score.data());
1088
      for (int k = 0; k < num_class_; ++k) {
1089
        init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1090
1091
      }
      // set label
1092
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1093
1094
1095
1096
1097
      // free processed line:
      text_data[i].clear();
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
Guolin Ke's avatar
Guolin Ke committed
1098
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1099
      for (auto& inner_data : oneline_features) {
1100
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1101
1102
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1103
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1104
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1105
1106
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
1107
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1108
1109
        } else {
          if (inner_data.first == weight_idx_) {
1110
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1111
1112
1113
1114
1115
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1116
      dataset->FinishOneRow(tid, i, is_feature_added);
1117
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1118
    }
1119
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1120
    // metadata_ will manage space of init_score
1121
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1122
  }
Guolin Ke's avatar
Guolin Ke committed
1123
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1124
  // text data can be free after loaded feature values
Guolin Ke's avatar
Guolin Ke committed
1125
  text_data->clear();
Guolin Ke's avatar
Guolin Ke committed
1126
1127
1128
}

/*! \brief Extract local features from file */
1129
1130
void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser,
                                            const std::vector<data_size_t>& used_data_indices, Dataset* dataset) {
1131
  std::vector<double> init_score;
Guolin Ke's avatar
Guolin Ke committed
1132
  if (predict_fun_ != nullptr) {
1133
    init_score = std::vector<double>(dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1134
1135
1136
1137
1138
1139
  }
  std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
    [this, &init_score, &parser, &dataset]
  (data_size_t start_idx, const std::vector<std::string>& lines) {
    std::vector<std::pair<int, double>> oneline_features;
    double tmp_label = 0.0f;
1140
    OMP_INIT_EX();
1141
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1142
    for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
1143
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1144
1145
1146
1147
1148
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1149
      if (!init_score.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1150
1151
        std::vector<double> oneline_init_score(num_class_);
        predict_fun_(oneline_features, oneline_init_score.data());
1152
        for (int k = 0; k < num_class_; ++k) {
1153
          init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1154
1155
1156
        }
      }
      // set label
1157
      dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1158
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1159
1160
      // push data
      for (auto& inner_data : oneline_features) {
1161
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1162
1163
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1164
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1165
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1166
1167
1168
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1169
1170
        } else {
          if (inner_data.first == weight_idx_) {
1171
            dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1172
1173
1174
1175
1176
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1177
      dataset->FinishOneRow(tid, i, is_feature_added);
1178
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1179
    }
1180
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1181
  };
1182
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
1183
  if (!used_data_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1184
1185
1186
1187
1188
1189
1190
1191
    // only need part of data
    text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
  } else {
    // need full data
    text_reader.ReadAllAndProcessParallel(process_fun);
  }

  // metadata_ will manage space of init_score
Guolin Ke's avatar
Guolin Ke committed
1192
  if (!init_score.empty()) {
1193
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1194
  }
Guolin Ke's avatar
Guolin Ke committed
1195
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1196
1197
1198
}

/*! \brief Check can load from binary file */
1199
std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
1200
1201
1202
  std::string bin_filename(filename);
  bin_filename.append(".bin");

1203
  auto reader = VirtualFileReader::Make(bin_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
1204

1205
  if (!reader->Init()) {
1206
    bin_filename = std::string(filename);
1207
1208
    reader = VirtualFileReader::Make(bin_filename.c_str());
    if (!reader->Init()) {
1209
      Log::Fatal("Cannot open data file %s", bin_filename.c_str());
1210
    }
1211
  }
1212
1213
1214
1215
1216

  size_t buffer_size = 256;
  auto buffer = std::vector<char>(buffer_size);
  // read size of token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
1217
  size_t read_cnt = reader->Read(buffer.data(), size_of_token);
1218
1219
  if (read_cnt == size_of_token
      && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
1220
    return bin_filename;
Guolin Ke's avatar
Guolin Ke committed
1221
  } else {
1222
    return std::string();
Guolin Ke's avatar
Guolin Ke committed
1223
1224
1225
  }
}

1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238


std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
                                                              const std::unordered_set<int>& categorical_features) {
  std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
  if (forced_bins_path != "") {
    std::ifstream forced_bins_stream(forced_bins_path.c_str());
    if (forced_bins_stream.fail()) {
      Log::Warning("Could not open %s. Will ignore.", forced_bins_path.c_str());
    } else {
      std::stringstream buffer;
      buffer << forced_bins_stream.rdbuf();
      std::string err;
Guolin Ke's avatar
Guolin Ke committed
1239
      Json forced_bins_json = Json::parse(buffer.str(), &err);
1240
1241
1242
1243
      CHECK(forced_bins_json.is_array());
      std::vector<Json> forced_bins_arr = forced_bins_json.array_items();
      for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
        int feature_num = forced_bins_arr[i]["feature"].int_value();
Nikita Titov's avatar
Nikita Titov committed
1244
        CHECK_LT(feature_num, num_total_features);
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
        if (categorical_features.count(feature_num)) {
          Log::Warning("Feature %d is categorical. Will ignore forced bins for this  feature.", feature_num);
        } else {
          std::vector<Json> bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items();
          for (size_t j = 0; j < bounds_arr.size(); ++j) {
            forced_bins[feature_num].push_back(bounds_arr[j].number_value());
          }
        }
      }
      // remove duplicates
      for (int i = 0; i < num_total_features; ++i) {
        auto new_end = std::unique(forced_bins[i].begin(), forced_bins[i].end());
        forced_bins[i].erase(new_end, forced_bins[i].end());
      }
    }
  }
  return forced_bins;
}

1264
}  // namespace LightGBM