dataset_loader.cpp 62 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/dataset_loader.h>

Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/network.h>
8
#include <LightGBM/utils/array_args.h>
9
#include <LightGBM/utils/json11.h>
10
11
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
12

13
#include <chrono>
14
15
#include <fstream>

Guolin Ke's avatar
Guolin Ke committed
16
17
namespace LightGBM {

18
19
using json11::Json;

Guolin Ke's avatar
Guolin Ke committed
20
21
DatasetLoader::DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename)
  :config_(io_config), random_(config_.data_random_seed), predict_fun_(predict_fun), num_class_(num_class) {
  // Default column roles: label in column 0, no weight column, no group/query column.
  // SetHeader() may override any of these from the config.
  label_idx_ = 0;
  weight_idx_ = NO_SPECIFIC;
  group_idx_ = NO_SPECIFIC;
  SetHeader(filename);
  // Raw feature values are kept only when linear trees are requested.
  store_raw_ = io_config.linear_tree;
}

// No explicit cleanup is performed; members are destroyed automatically.
DatasetLoader::~DatasetLoader() = default;

Guolin Ke's avatar
Guolin Ke committed
35
void DatasetLoader::SetHeader(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
36
  std::unordered_map<std::string, int> name2idx;
Guolin Ke's avatar
Guolin Ke committed
37
  std::string name_prefix("name:");
38
  if (filename != nullptr && CheckCanLoadFromBin(filename) == "") {
Guolin Ke's avatar
Guolin Ke committed
39
    TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
40

Guolin Ke's avatar
Guolin Ke committed
41
    // get column names
Guolin Ke's avatar
Guolin Ke committed
42
    if (config_.header) {
Guolin Ke's avatar
Guolin Ke committed
43
      std::string first_line = text_reader.first_line();
44
      feature_names_ = Common::Split(first_line.c_str(), "\t,");
45
46
47
48
49
50
51
52
53
54
55
56
    } else if (!config_.parser_config_file.empty()) {
      // support to get header from parser config, so could utilize following label name to id mapping logic.
      TextReader<data_size_t> parser_config_reader(config_.parser_config_file.c_str(), false);
      parser_config_reader.ReadAllLines();
      std::string parser_config_str = parser_config_reader.JoinedLines();
      if (!parser_config_str.empty()) {
        std::string header_in_parser_config = Common::GetFromParserConfig(parser_config_str, "header");
        if (!header_in_parser_config.empty()) {
          Log::Info("Get raw column names from parser config.");
          feature_names_ = Common::Split(header_in_parser_config.c_str(), "\t,");
        }
      }
Guolin Ke's avatar
Guolin Ke committed
57
58
    }

Guolin Ke's avatar
Guolin Ke committed
59
    // load label idx first
Guolin Ke's avatar
Guolin Ke committed
60
61
62
    if (config_.label_column.size() > 0) {
      if (Common::StartsWith(config_.label_column, name_prefix)) {
        std::string name = config_.label_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
63
64
65
66
67
68
69
70
71
72
        label_idx_ = -1;
        for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
          if (name == feature_names_[i]) {
            label_idx_ = i;
            break;
          }
        }
        if (label_idx_ >= 0) {
          Log::Info("Using column %s as label", name.c_str());
        } else {
73
74
          Log::Fatal("Could not find label column %s in data file \n"
                     "or data file doesn't contain header", name.c_str());
Guolin Ke's avatar
Guolin Ke committed
75
        }
Guolin Ke's avatar
Guolin Ke committed
76
      } else {
Guolin Ke's avatar
Guolin Ke committed
77
        if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
78
79
80
          Log::Fatal("label_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
81
82
        }
        Log::Info("Using column number %d as label", label_idx_);
Guolin Ke's avatar
Guolin Ke committed
83
84
      }
    }
Guolin Ke's avatar
Guolin Ke committed
85

86
87
88
89
90
91
92
93
94
    if (!config_.parser_config_file.empty()) {
      // if parser config file exists, feature names may be changed after customized parser applied.
      // clear here so could use default filled feature names during dataset construction.
      // may improve by saving real feature names defined in parser in the future.
      if (!feature_names_.empty()) {
        feature_names_.clear();
      }
    }

Guolin Ke's avatar
Guolin Ke committed
95
    if (!feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
96
97
98
99
      // erase label column name
      feature_names_.erase(feature_names_.begin() + label_idx_);
      for (size_t i = 0; i < feature_names_.size(); ++i) {
        name2idx[feature_names_[i]] = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
100
      }
Guolin Ke's avatar
Guolin Ke committed
101
102
103
    }

    // load ignore columns
Guolin Ke's avatar
Guolin Ke committed
104
105
106
    if (config_.ignore_column.size() > 0) {
      if (Common::StartsWith(config_.ignore_column, name_prefix)) {
        std::string names = config_.ignore_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
107
108
109
110
111
112
113
114
115
        for (auto name : Common::Split(names.c_str(), ',')) {
          if (name2idx.count(name) > 0) {
            int tmp = name2idx[name];
            ignore_features_.emplace(tmp);
          } else {
            Log::Fatal("Could not find ignore column %s in data file", name.c_str());
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
116
        for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
Guolin Ke's avatar
Guolin Ke committed
117
118
          int tmp = 0;
          if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
119
120
121
            Log::Fatal("ignore_column is not a number,\n"
                       "if you want to use a column name,\n"
                       "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
122
123
          }
          ignore_features_.emplace(tmp);
Guolin Ke's avatar
Guolin Ke committed
124
125
126
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
127
    // load weight idx
Guolin Ke's avatar
Guolin Ke committed
128
129
130
    if (config_.weight_column.size() > 0) {
      if (Common::StartsWith(config_.weight_column, name_prefix)) {
        std::string name = config_.weight_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
131
132
133
134
135
136
        if (name2idx.count(name) > 0) {
          weight_idx_ = name2idx[name];
          Log::Info("Using column %s as weight", name.c_str());
        } else {
          Log::Fatal("Could not find weight column %s in data file", name.c_str());
        }
Guolin Ke's avatar
Guolin Ke committed
137
      } else {
Guolin Ke's avatar
Guolin Ke committed
138
        if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
139
140
141
          Log::Fatal("weight_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
142
143
        }
        Log::Info("Using column number %d as weight", weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
144
      }
Guolin Ke's avatar
Guolin Ke committed
145
      ignore_features_.emplace(weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
146
    }
Guolin Ke's avatar
Guolin Ke committed
147
    // load group idx
Guolin Ke's avatar
Guolin Ke committed
148
149
150
    if (config_.group_column.size() > 0) {
      if (Common::StartsWith(config_.group_column, name_prefix)) {
        std::string name = config_.group_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
151
152
153
154
155
156
157
        if (name2idx.count(name) > 0) {
          group_idx_ = name2idx[name];
          Log::Info("Using column %s as group/query id", name.c_str());
        } else {
          Log::Fatal("Could not find group/query column %s in data file", name.c_str());
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
158
        if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
159
160
161
          Log::Fatal("group_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
162
163
164
165
        }
        Log::Info("Using column number %d as group/query id", group_idx_);
      }
      ignore_features_.emplace(group_idx_);
Guolin Ke's avatar
Guolin Ke committed
166
167
    }
  }
Guolin Ke's avatar
Guolin Ke committed
168
169
170
  if (config_.categorical_feature.size() > 0) {
    if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
      std::string names = config_.categorical_feature.substr(name_prefix.size());
171
172
173
174
175
      for (auto name : Common::Split(names.c_str(), ',')) {
        if (name2idx.count(name) > 0) {
          int tmp = name2idx[name];
          categorical_features_.emplace(tmp);
        } else {
Guolin Ke's avatar
Guolin Ke committed
176
          Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
177
178
179
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
180
      for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
181
182
        int tmp = 0;
        if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Guolin Ke's avatar
Guolin Ke committed
183
          Log::Fatal("categorical_feature is not a number,\n"
184
185
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
186
187
188
189
190
        }
        categorical_features_.emplace(tmp);
      }
    }
  }
Guolin Ke's avatar
Guolin Ke committed
191
192
}

193
194
195
196
197
198
199
200
201
202
/*!
 * \brief Warn when the sample used to construct feature bins looks too small.
 *
 * A warning is logged only when fewer than 100000 rows were sampled AND the
 * sample covers less than 20% of the data.
 *
 * \param sample_cnt Number of sampled rows.
 * \param num_data Total number of rows in the dataset.
 */
void CheckSampleSize(size_t sample_cnt, size_t num_data) {
  // Large absolute samples are considered sufficient regardless of the ratio.
  if (sample_cnt >= 100000) {
    return;
  }
  const double sampled_fraction = static_cast<double>(sample_cnt) / num_data;
  if (sampled_fraction < 0.2f) {
    Log::Warning(
        "Using too small ``bin_construct_sample_cnt`` may encounter "
        "unexpected errors and poor accuracy.");
  }
}

203
/*!
 * \brief Load a training dataset from \p filename.
 *
 * If CheckCanLoadFromBin() returns an empty name, the file is parsed as text:
 * bin mappers are built from a sample of rows, then features are extracted,
 * either from rows held in memory or, when config_.two_round is set, in a
 * second pass over the file. Otherwise the binary file it names is loaded via
 * LoadFromBinFile(). In both cases metadata is checked/partitioned and the
 * result is validated with CheckDataset().
 *
 * \param filename Path to the data file.
 * \param rank Rank of this machine in distributed training.
 * \param num_machines Number of machines taking part in training.
 * \return Newly allocated dataset; ownership is released to the caller.
 */
Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_machines) {
  // don't support query id in data file when using distributed training.
  // group_idx_ starts as NO_SPECIFIC and only changes when a group column is
  // configured, so any other value (including column 0) means a query column is used.
  if (num_machines > 1 && !config_.pre_partition) {
    if (group_idx_ != NO_SPECIFIC) {
      Log::Fatal("Using a query id without pre-partitioning the data file is not supported for distributed training.\n"
                 "Please use an additional query file or pre-partition the data");
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  if (store_raw_) {
    dataset->SetHasRaw(true);
  }
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto bin_filename = CheckCanLoadFromBin(filename);
  bool is_load_from_binary = false;
  if (bin_filename.size() == 0) {
    dataset->parser_config_str_ = Parser::GenerateParserConfigStr(filename, config_.parser_config_file.c_str(), config_.header, label_idx_);
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_,
                                                               config_.precise_float_parser, dataset->parser_config_str_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data to memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // sample data
      auto sample_data = SampleTextDataFromMemory(text_data);
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers & clear sample data
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      std::vector<std::string>().swap(sample_data);  // release sample memory
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      // sample data from file
      auto sample_data = SampleTextDataFromFile(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      // a non-empty index list means only a subset of rows belongs to this machine
      if (used_data_indices.size() > 0) {
        dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
      } else {
        dataset->num_data_ = num_global_data;
      }
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers & clear sample data
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      std::vector<std::string>().swap(sample_data);  // release sample memory
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      Log::Info("Making second pass...");
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    is_load_from_binary = true;
    Log::Info("Load from binary file %s", bin_filename.c_str());
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
  }
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  // need to check training data
  CheckDataset(dataset.get(), is_load_from_binary);

  return dataset.release();
}

284
/*!
 * \brief Load a validation dataset whose feature binning is aligned with \p train_data.
 *
 * Text input is parsed with the same parser configuration as the training data,
 * and bin mappers come from train_data via Dataset::CreateValid(). No sampling is
 * done. A binary file reported by CheckCanLoadFromBin() is loaded directly instead.
 * The data is always treated as single-machine (rank 0 of 1), and CheckDataset()
 * is not run on the result.
 *
 * \param filename Path to the validation data file.
 * \param train_data Training dataset to align with; must be non-null.
 * \return Newly allocated dataset; ownership is released to the caller.
 */
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  if (store_raw_) {
    dataset->SetHasRaw(true);
  }
  auto bin_filename = CheckCanLoadFromBin(filename);
  if (bin_filename.size() == 0) {
    // Parse with the training data's parser config: the freshly created dataset's
    // own parser_config_str_ has not been filled, so using it would silently drop
    // any custom parser and misalign columns with the training data.
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_,
                                                               config_.precise_float_parser, train_data->parser_config_str_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data in memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      dataset->CreateValid(train_data);
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      TextReader<data_size_t> text_reader(filename, config_.header);
      // Get number of lines of data file
      dataset->num_data_ = static_cast<data_size_t>(text_reader.CountLine());
      num_global_data = dataset->num_data_;
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      dataset->CreateValid(train_data);
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), 0, 1, &num_global_data, &used_data_indices));
  }
  // not need to check validation data
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  return dataset.release();
}

338
339
340
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
                                        int rank, int num_machines, int* num_global_data,
                                        std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
341
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
342
  auto reader = VirtualFileReader::Make(bin_filename);
Guolin Ke's avatar
Guolin Ke committed
343
  dataset->data_filename_ = data_filename;
344
  if (!reader->Init()) {
Guolin Ke's avatar
Guolin Ke committed
345
346
347
348
349
    Log::Fatal("Could not read binary data from %s", bin_filename);
  }

  // buffer to read binary file
  size_t buffer_size = 16 * 1024 * 1024;
Guolin Ke's avatar
Guolin Ke committed
350
  auto buffer = std::vector<char>(buffer_size);
351

352
353
  // check token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
354
355
356
357
  size_t read_cnt = reader->Read(
      buffer.data(),
      VirtualFileWriter::AlignedSize(sizeof(char) * size_of_token));
  if (read_cnt < sizeof(char) * size_of_token) {
358
359
360
    Log::Fatal("Binary file error: token has the wrong size");
  }
  if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
361
    Log::Fatal("Input file is not LightGBM binary file");
362
  }
Guolin Ke's avatar
Guolin Ke committed
363
364

  // read size of header
365
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
366

367
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
368
369
370
    Log::Fatal("Binary file error: header has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
371
  size_t size_of_head = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
372
373
374
375

  // re-allocmate space if not enough
  if (size_of_head > buffer_size) {
    buffer_size = size_of_head;
Guolin Ke's avatar
Guolin Ke committed
376
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
377
378
  }
  // read header
379
  read_cnt = reader->Read(buffer.data(), size_of_head);
Guolin Ke's avatar
Guolin Ke committed
380
381
382
383
384

  if (read_cnt != size_of_head) {
    Log::Fatal("Binary file error: header is incorrect");
  }
  // get header
Guolin Ke's avatar
Guolin Ke committed
385
  const char* mem_ptr = buffer.data();
Guolin Ke's avatar
Guolin Ke committed
386
  dataset->num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
387
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_data_));
Guolin Ke's avatar
Guolin Ke committed
388
  dataset->num_features_ = *(reinterpret_cast<const int*>(mem_ptr));
389
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_features_));
Guolin Ke's avatar
Guolin Ke committed
390
  dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
391
392
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(dataset->num_total_features_));
Guolin Ke's avatar
Guolin Ke committed
393
  dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
394
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->label_idx_));
395
  dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
396
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->max_bin_));
397
  dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
398
399
  mem_ptr += VirtualFileWriter::AlignedSize(
      sizeof(dataset->bin_construct_sample_cnt_));
400
  dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
401
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->min_data_in_bin_));
402
  dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
403
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->use_missing_));
404
  dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
405
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->zero_as_missing_));
406
407
  dataset->has_raw_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->has_raw_));
Guolin Ke's avatar
Guolin Ke committed
408
409
  const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
  dataset->used_feature_map_.clear();
Guolin Ke's avatar
Guolin Ke committed
410
  for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
411
412
    dataset->used_feature_map_.push_back(tmp_feature_map[i]);
  }
413
414
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int) *
                                            dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
415
416
  // num_groups
  dataset->num_groups_ = *(reinterpret_cast<const int*>(mem_ptr));
417
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
418
419
420
421
422
423
  // real_feature_idx_
  const int* tmp_ptr_real_feature_idx_ = reinterpret_cast<const int*>(mem_ptr);
  dataset->real_feature_idx_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->real_feature_idx_.push_back(tmp_ptr_real_feature_idx_[i]);
  }
424
425
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
426
427
428
429
430
431
  // feature2group
  const int* tmp_ptr_feature2group = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2group_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2group_.push_back(tmp_ptr_feature2group[i]);
  }
432
433
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
434
435
436
437
438
439
  // feature2subfeature
  const int* tmp_ptr_feature2subfeature = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2subfeature_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2subfeature_.push_back(tmp_ptr_feature2subfeature[i]);
  }
440
441
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
442
443
444
445
446
447
448
449
450
451
452
  // group_bin_boundaries
  const uint64_t* tmp_ptr_group_bin_boundaries = reinterpret_cast<const uint64_t*>(mem_ptr);
  dataset->group_bin_boundaries_.clear();
  for (int i = 0; i < dataset->num_groups_ + 1; ++i) {
    dataset->group_bin_boundaries_.push_back(tmp_ptr_group_bin_boundaries[i]);
  }
  mem_ptr += sizeof(uint64_t) * (dataset->num_groups_ + 1);

  // group_feature_start_
  const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_start_.clear();
453
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
454
455
    dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
  }
456
457
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
458
459
460
461
462
463
464

  // group_feature_cnt_
  const int* tmp_ptr_group_feature_cnt = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_cnt_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_cnt_.push_back(tmp_ptr_group_feature_cnt[i]);
  }
465
466
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
467

Belinda Trotta's avatar
Belinda Trotta committed
468
  if (!config_.max_bin_by_feature.empty()) {
469
470
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
471
472
473
474
475
476
477
478
479
    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
  } else {
    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
    dataset->max_bin_by_feature_.clear();
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
    }
  }
480
481
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int32_t) *
                                            (dataset->num_total_features_));
Belinda Trotta's avatar
Belinda Trotta committed
482
483
484
485
  if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
    dataset->max_bin_by_feature_.clear();
  }

Guolin Ke's avatar
Guolin Ke committed
486
  // get feature names
Guolin Ke's avatar
Guolin Ke committed
487
  dataset->feature_names_.clear();
Guolin Ke's avatar
Guolin Ke committed
488
489
490
  // write feature names
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int str_len = *(reinterpret_cast<const int*>(mem_ptr));
491
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
Guolin Ke's avatar
Guolin Ke committed
492
    std::stringstream str_buf;
493
    auto tmp_arr = reinterpret_cast<const char*>(mem_ptr);
Guolin Ke's avatar
Guolin Ke committed
494
    for (int j = 0; j < str_len; ++j) {
495
      char tmp_char = tmp_arr[j];
Guolin Ke's avatar
Guolin Ke committed
496
497
      str_buf << tmp_char;
    }
498
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(char) * str_len);
Guolin Ke's avatar
Guolin Ke committed
499
    dataset->feature_names_.emplace_back(str_buf.str());
Guolin Ke's avatar
Guolin Ke committed
500
  }
501
502
503
504
  // get forced_bin_bounds_
  dataset->forced_bin_bounds_ = std::vector<std::vector<double>>(dataset->num_total_features_, std::vector<double>());
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int num_bounds = *(reinterpret_cast<const int*>(mem_ptr));
505
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
506
    dataset->forced_bin_bounds_[i] = std::vector<double>();
507
508
    const double* tmp_ptr_forced_bounds =
        reinterpret_cast<const double*>(mem_ptr);
509
510
511
512
513
514
    for (int j = 0; j < num_bounds; ++j) {
      double bound = tmp_ptr_forced_bounds[j];
      dataset->forced_bin_bounds_[i].push_back(bound);
    }
    mem_ptr += num_bounds * sizeof(double);
  }
Guolin Ke's avatar
Guolin Ke committed
515
516

  // read size of meta data
517
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
518

519
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
520
521
522
    Log::Fatal("Binary file error: meta data has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
523
  size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
524
525
526
527

  // re-allocate space if not enough
  if (size_of_metadata > buffer_size) {
    buffer_size = size_of_metadata;
Guolin Ke's avatar
Guolin Ke committed
528
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
529
530
  }
  //  read meta data
531
  read_cnt = reader->Read(buffer.data(), size_of_metadata);
Guolin Ke's avatar
Guolin Ke committed
532
533
534
535
536

  if (read_cnt != size_of_metadata) {
    Log::Fatal("Binary file error: meta data is incorrect");
  }
  // load meta data
Guolin Ke's avatar
Guolin Ke committed
537
  dataset->metadata_.LoadFromMemory(buffer.data());
Guolin Ke's avatar
Guolin Ke committed
538

539
540
  *num_global_data = dataset->num_data_;
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
541
  // sample local used data if need to partition
Guolin Ke's avatar
Guolin Ke committed
542
  if (num_machines > 1 && !config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
543
544
545
546
    const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
547
        if (random_.NextShort(0, num_machines) == rank) {
548
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
549
550
551
552
553
554
555
556
557
        }
      }
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = dataset->metadata_.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (qid >= num_queries) {
558
559
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
560
561
562
563
        }
        if (i >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
564
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
565
566
567
568
569
            is_query_used = true;
          }
          ++qid;
        }
        if (is_query_used) {
570
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
571
572
573
        }
      }
    }
574
    dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
Guolin Ke's avatar
Guolin Ke committed
575
  }
576
  dataset->metadata_.PartitionLabel(*used_data_indices);
Guolin Ke's avatar
Guolin Ke committed
577
  // read feature data
Guolin Ke's avatar
Guolin Ke committed
578
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
579
    // read feature size
580
581
    read_cnt = reader->Read(buffer.data(), sizeof(size_t));
    if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
582
583
      Log::Fatal("Binary file error: feature %d has the wrong size", i);
    }
Guolin Ke's avatar
Guolin Ke committed
584
    size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
585
586
587
    // re-allocate space if not enough
    if (size_of_feature > buffer_size) {
      buffer_size = size_of_feature;
Guolin Ke's avatar
Guolin Ke committed
588
      buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
589
590
    }

591
    read_cnt = reader->Read(buffer.data(), size_of_feature);
Guolin Ke's avatar
Guolin Ke committed
592
593
594
595

    if (read_cnt != size_of_feature) {
      Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
    }
Guolin Ke's avatar
Guolin Ke committed
596
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
597
598
      new FeatureGroup(buffer.data(),
                       *num_global_data,
599
                       *used_data_indices, i)));
Guolin Ke's avatar
Guolin Ke committed
600
  }
Guolin Ke's avatar
Guolin Ke committed
601
  dataset->feature_groups_.shrink_to_fit();
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637

  // raw data
  dataset->numeric_feature_map_ = std::vector<int>(dataset->num_features_, false);
  dataset->num_numeric_features_ = 0;
  for (int i = 0; i < dataset->num_features_; ++i) {
    if (dataset->FeatureBinMapper(i)->bin_type() == BinType::CategoricalBin) {
      dataset->numeric_feature_map_[i] = -1;
    } else {
      dataset->numeric_feature_map_[i] = dataset->num_numeric_features_;
      ++dataset->num_numeric_features_;
    }
  }
  if (dataset->has_raw()) {
    dataset->ResizeRaw(dataset->num_data());
      size_t row_size = dataset->num_numeric_features_ * sizeof(float);
      if (row_size > buffer_size) {
        buffer_size = row_size;
        buffer.resize(buffer_size);
      }
    for (int i = 0; i < dataset->num_data(); ++i) {
      read_cnt = reader->Read(buffer.data(), row_size);
      if (read_cnt != row_size) {
        Log::Fatal("Binary file error: row %d of raw data is incorrect, read count: %d", i, read_cnt);
      }
      mem_ptr = buffer.data();
      const float* tmp_ptr_raw_row = reinterpret_cast<const float*>(mem_ptr);
      for (int j = 0; j < dataset->num_features(); ++j) {
        int feat_ind = dataset->numeric_feature_map_[j];
        if (feat_ind >= 0) {
          dataset->raw_data_[feat_ind][i] = tmp_ptr_raw_row[feat_ind];
        }
      }
      mem_ptr += row_size;
    }
  }

Guolin Ke's avatar
Guolin Ke committed
638
  dataset->is_finish_load_ = true;
Guolin Ke's avatar
Guolin Ke committed
639
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
640
641
}

642

643
644
645
/*!
 * \brief Build a Dataset (bin mappers, feature groups, feature names) from per-column samples.
 * \param sample_values Per-column arrays of sampled values; column c holds num_per_col[c] entries.
 * \param sample_indices Per-column arrays forwarded to Dataset::Construct together with sample_values
 *        (presumably the sampled row index of each value -- confirm against Dataset::Construct).
 * \param num_col Number of columns available locally.
 * \param num_per_col Number of sampled values stored for each column.
 * \param total_sample_size Total number of sampled rows.
 * \param num_data Number of rows the resulting Dataset is sized for.
 * \return Newly allocated Dataset; ownership is released to the caller.
 */
Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
                                                int** sample_indices, int num_col, const int* num_per_col,
                                                size_t total_sample_size, data_size_t num_data) {
  CheckSampleSize(total_sample_size, static_cast<size_t>(num_data));
  int num_total_features = num_col;
  if (Network::num_machines() > 1) {
    // all machines must agree on the feature count
    num_total_features = Network::GlobalSyncUpByMax(num_total_features);
  }
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_total_features);
  // fill feature_names_ if not header
  if (feature_names_.empty()) {
    for (int i = 0; i < num_col; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
  if (!config_.max_bin_by_feature.empty()) {
    CHECK_EQ(static_cast<size_t>(num_col), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
  }

  // get forced split; the result is indexed by global feature (column) index
  std::string forced_bins_path = config_.forcedbins_filename;
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);

  const data_size_t filter_cnt = static_cast<data_size_t>(
    static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
  if (Network::num_machines() == 1) {
    // if only one machine, find bin locally
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
        bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
        if (!feat_is_unconstrained) {
            Log::Fatal("The output cannot be monotone with respect to categorical features");
        }
      }
      bin_mappers[i].reset(new BinMapper());
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                                config_.max_bin, config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter,
                                bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[i]);
      } else {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                                config_.max_bin_by_feature[i], config_.min_data_in_bin,
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
                                config_.zero_as_missing, forced_bin_bounds[i]);
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    // if have multi-machines, need to find bin distributed
    // different machines will find bin for different features
    int num_machines = Network::num_machines();
    int rank = Network::rank();
    // start and len will store the process feature indices for different machines
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
    int step = (num_total_features + num_machines - 1) / num_machines;
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
      len[i] = std::min(step, num_total_features - start[i]);
      start[i + 1] = start[i] + len[i];
    }
    len[num_machines - 1] = num_total_features - start[num_machines - 1];
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < len[rank]; ++i) {
      OMP_LOOP_EX_BEGIN();
      // bin_mappers uses the local slot i; every per-feature input (samples, max_bin_by_feature,
      // forced_bin_bounds) is indexed by the global feature index.
      const int feature_idx = start[rank] + i;
      if (ignore_features_.count(feature_idx) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(feature_idx)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
      if (num_col <= feature_idx) {
        // feature not present locally: keep a default mapper so the exchanged buffer stays aligned
        continue;
      }
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[feature_idx], num_per_col[feature_idx],
                                total_sample_size, config_.max_bin, config_.min_data_in_bin,
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[feature_idx]);
      } else {
        bin_mappers[i]->FindBin(sample_values[feature_idx], num_per_col[feature_idx],
                                total_sample_size, config_.max_bin_by_feature[feature_idx],
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
                                config_.zero_as_missing, forced_bin_bounds[feature_idx]);
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
    comm_size_t self_buf_size = 0;
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
    }
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
      // free
      bin_mappers[i].reset(nullptr);
    }
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
    }
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
    // gather global feature bin mappers
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
    // restore features bins from buffer
    for (int i = 0; i < num_total_features; ++i) {
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      bin_mappers[i].reset(new BinMapper());
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
  dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
  if (dataset->has_raw()) {
    dataset->ResizeRaw(num_data);
  }
  dataset->set_feature_names(feature_names_);
  return dataset.release();
}
Guolin Ke's avatar
Guolin Ke committed
798
799
800
801


// ---- private functions ----

802
void DatasetLoader::CheckDataset(const Dataset* dataset, bool is_load_from_binary) {
Guolin Ke's avatar
Guolin Ke committed
803
  if (dataset->num_data_ <= 0) {
Guolin Ke's avatar
Guolin Ke committed
804
    Log::Fatal("Data file %s is empty", dataset->data_filename_.c_str());
Guolin Ke's avatar
Guolin Ke committed
805
  }
806
807
  if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
    Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
808
               static_cast<int>(dataset->feature_names_.size()));
809
  }
Guolin Ke's avatar
Guolin Ke committed
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
  bool is_feature_order_by_group = true;
  int last_group = -1;
  int last_sub_feature = -1;
  // if features are ordered, not need to use hist_buf
  for (int i = 0; i < dataset->num_features_; ++i) {
    int group = dataset->feature2group_[i];
    int sub_feature = dataset->feature2subfeature_[i];
    if (group < last_group) {
      is_feature_order_by_group = false;
    } else if (group == last_group) {
      if (sub_feature <= last_sub_feature) {
        is_feature_order_by_group = false;
        break;
      }
    }
    last_group = group;
    last_sub_feature = sub_feature;
  }
  if (!is_feature_order_by_group) {
829
    Log::Fatal("Features in dataset should be ordered by group");
Guolin Ke's avatar
Guolin Ke committed
830
  }
831
832
833

  if (is_load_from_binary) {
    if (dataset->max_bin_ != config_.max_bin) {
834
835
      Log::Fatal("Dataset was constructed with parameter max_bin=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->max_bin_, config_.max_bin);
836
837
    }
    if (dataset->min_data_in_bin_ != config_.min_data_in_bin) {
838
839
      Log::Fatal("Dataset was constructed with parameter min_data_in_bin=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->min_data_in_bin_, config_.min_data_in_bin);
840
841
    }
    if (dataset->use_missing_ != config_.use_missing) {
842
843
      Log::Fatal("Dataset was constructed with parameter use_missing=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->use_missing_, config_.use_missing);
844
845
    }
    if (dataset->zero_as_missing_ != config_.zero_as_missing) {
846
847
      Log::Fatal("Dataset was constructed with parameter zero_as_missing=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->zero_as_missing_, config_.zero_as_missing);
848
849
    }
    if (dataset->bin_construct_sample_cnt_ != config_.bin_construct_sample_cnt) {
850
851
      Log::Fatal("Dataset was constructed with parameter bin_construct_sample_cnt=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->bin_construct_sample_cnt_, config_.bin_construct_sample_cnt);
852
853
854
855
    }
    if ((dataset->max_bin_by_feature_.size() != config_.max_bin_by_feature.size()) ||
        !std::equal(dataset->max_bin_by_feature_.begin(), dataset->max_bin_by_feature_.end(),
            config_.max_bin_by_feature.begin())) {
856
      Log::Fatal("Parameter max_bin_by_feature cannot be changed when loading from binary file.");
857
858
    }

859
    if (config_.label_column != "") {
860
      Log::Warning("Parameter label_column works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
861
862
    }
    if (config_.weight_column != "") {
863
      Log::Warning("Parameter weight_column works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
864
865
    }
    if (config_.group_column != "") {
866
      Log::Warning("Parameter group_column works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
867
868
    }
    if (config_.ignore_column != "") {
869
      Log::Warning("Parameter ignore_column works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
870
    }
871
    if (config_.two_round) {
872
      Log::Warning("Parameter two_round works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
873
874
    }
    if (config_.header) {
875
      Log::Warning("Parameter header works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
876
    }
877
  }
Guolin Ke's avatar
Guolin Ke committed
878
879
880
}

std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
881
882
                                                             int rank, int num_machines, int* num_global_data,
                                                             std::vector<data_size_t>* used_data_indices) {
883
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
884
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
885
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
886
887
888
889
890
891
892
893
894
    // read all lines
    *num_global_data = text_reader.ReadAllLines();
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();

    if (query_boundaries == nullptr) {
      // if not contain query data, minimal sample unit is one record
      *num_global_data = text_reader.ReadAndFilterLines([this, rank, num_machines](data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
895
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
896
897
898
899
900
901
902
903
904
905
906
907
908
909
          return true;
        } else {
          return false;
        }
      }, used_data_indices);
    } else {
      // if contain query data, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.ReadAndFilterLines(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
910
911
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
912
913
914
915
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
916
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
917
918
919
920
921
922
923
924
925
926
927
928
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices);
    }
  }
  return std::move(text_reader.Lines());
}

std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
Guolin Ke's avatar
Guolin Ke committed
929
  int sample_cnt = config_.bin_construct_sample_cnt;
930
931
  if (static_cast<size_t>(sample_cnt) > data.size()) {
    sample_cnt = static_cast<int>(data.size());
932
  }
933
  auto sample_indices = random_.Sample(static_cast<int>(data.size()), sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
934
  std::vector<std::string> out(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
935
936
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    const size_t idx = sample_indices[i];
Guolin Ke's avatar
Guolin Ke committed
937
    out[i] = data[idx];
Guolin Ke's avatar
Guolin Ke committed
938
939
940
941
  }
  return out;
}

942
943
944
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata,
                                                               int rank, int num_machines, int* num_global_data,
                                                               std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
945
  const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
946
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
947
  std::vector<std::string> out_data;
Guolin Ke's avatar
Guolin Ke committed
948
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
949
    *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(&random_, sample_cnt, &out_data));
Guolin Ke's avatar
Guolin Ke committed
950
951
952
953
954
955
956
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      *num_global_data = text_reader.SampleAndFilterFromFile([this, rank, num_machines]
      (data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
957
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
958
959
960
961
          return true;
        } else {
          return false;
        }
Guolin Ke's avatar
Guolin Ke committed
962
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
963
964
965
966
967
968
969
970
971
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.SampleAndFilterFromFile(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
972
973
          Log::Fatal("Query id exceeds the range of the query file, "
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
974
975
976
977
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
978
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
979
980
981
982
983
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
Guolin Ke's avatar
Guolin Ke committed
984
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
985
986
987
988
989
    }
  }
  return out_data;
}

990
991
992
void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
                                                    const std::vector<std::string>& sample_data,
                                                    const Parser* parser, Dataset* dataset) {
993
  auto t1 = std::chrono::high_resolution_clock::now();
Guolin Ke's avatar
Guolin Ke committed
994
  std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
995
  std::vector<std::vector<int>> sample_indices;
Guolin Ke's avatar
Guolin Ke committed
996
997
  std::vector<std::pair<int, double>> oneline_features;
  double label;
Guolin Ke's avatar
Guolin Ke committed
998
  for (int i = 0; i < static_cast<int>(sample_data.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
999
1000
1001
1002
    oneline_features.clear();
    // parse features
    parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
    for (std::pair<int, double>& inner_data : oneline_features) {
1003
      if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
1004
1005
        sample_values.resize(inner_data.first + 1);
        sample_indices.resize(inner_data.first + 1);
1006
      }
Guolin Ke's avatar
Guolin Ke committed
1007
      if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
1008
1009
        sample_values[inner_data.first].emplace_back(inner_data.second);
        sample_indices[inner_data.first].emplace_back(i);
Guolin Ke's avatar
Guolin Ke committed
1010
1011
1012
1013
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
1014
  dataset->feature_groups_.clear();
1015
1016
1017
1018
1019
  dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
  if (num_machines > 1) {
    dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_);
  }
  if (!feature_names_.empty()) {
1020
    CHECK_EQ(dataset->num_total_features_, static_cast<int>(feature_names_.size()));
1021
  }
Guolin Ke's avatar
Guolin Ke committed
1022

Belinda Trotta's avatar
Belinda Trotta committed
1023
  if (!config_.max_bin_by_feature.empty()) {
1024
1025
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
1026
1027
  }

1028
1029
  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
1030
1031
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path,
                                                                                    dataset->num_total_features_,
1032
1033
                                                                                    categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
1034
  // check the range of label_idx, weight_idx and group_idx
1035
1036
1037
1038
1039
  // skip label check if user input parser config file,
  // because label id is got from raw features while dataset features are consistent with customized parser.
  if (dataset->parser_config_str_.empty()) {
    CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
  }
Guolin Ke's avatar
Guolin Ke committed
1040
1041
1042
1043
  CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
  CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);

  // fill feature_names_ if not header
Guolin Ke's avatar
Guolin Ke committed
1044
  if (feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1045
1046
1047
1048
1049
1050
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
1051
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
1052
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
1053
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
1054
    static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
Guolin Ke's avatar
Guolin Ke committed
1055
1056
1057
  // start find bins
  if (num_machines == 1) {
    // if only one machine, find bin locally
1058
    OMP_INIT_EX();
1059
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
1060
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
1061
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1062
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1063
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1064
1065
        continue;
      }
1066
1067
1068
1069
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
Guolin Ke's avatar
Guolin Ke committed
1070
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
1071
1072
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
1073
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
1074
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
1075
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1076
1077
      } else {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
1078
                                sample_data.size(), config_.max_bin_by_feature[i],
1079
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
1080
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1081
      }
1082
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1083
    }
1084
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1085
1086
  } else {
    // start and len will store the process feature indices for different machines
1087
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
Guolin Ke's avatar
Guolin Ke committed
1088
1089
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
1090
    int step = (dataset->num_total_features_ + num_machines - 1) / num_machines;
Guolin Ke's avatar
Guolin Ke committed
1091
1092
1093
1094
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
1095
      len[i] = std::min(step, dataset->num_total_features_ - start[i]);
Guolin Ke's avatar
Guolin Ke committed
1096
1097
      start[i + 1] = start[i] + len[i];
    }
1098
    len[num_machines - 1] = dataset->num_total_features_ - start[num_machines - 1];
1099
    OMP_INIT_EX();
1100
    #pragma omp parallel for schedule(guided)
1101
    for (int i = 0; i < len[rank]; ++i) {
1102
      OMP_LOOP_EX_BEGIN();
1103
1104
1105
1106
1107
1108
1109
1110
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
Nikita Titov's avatar
Nikita Titov committed
1111
      if (static_cast<int>(sample_values.size()) <= start[rank] + i) {
1112
1113
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
1114
      if (config_.max_bin_by_feature.empty()) {
1115
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
1116
                                static_cast<int>(sample_values[start[rank] + i].size()),
1117
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
1118
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
1119
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1120
      } else {
1121
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
1122
                                static_cast<int>(sample_values[start[rank] + i].size()),
1123
                                sample_data.size(), config_.max_bin_by_feature[i],
1124
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type,
1125
                                config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1126
      }
1127
      OMP_LOOP_EX_END();
1128
    }
1129
    OMP_THROW_EX();
1130
    comm_size_t self_buf_size = 0;
Guolin Ke's avatar
Guolin Ke committed
1131
    for (int i = 0; i < len[rank]; ++i) {
1132
1133
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
Guolin Ke's avatar
Guolin Ke committed
1134
      }
1135
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
1136
    }
1137
1138
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1139
    for (int i = 0; i < len[rank]; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1140
1141
1142
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
1143
1144
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
1145
1146
      // free
      bin_mappers[i].reset(nullptr);
Guolin Ke's avatar
Guolin Ke committed
1147
    }
1148
1149
1150
1151
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
Guolin Ke's avatar
Guolin Ke committed
1152
    }
1153
1154
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
Guolin Ke's avatar
Guolin Ke committed
1155
    // gather global feature bin mappers
1156
1157
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1158
    // restore features bins from buffer
1159
    for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1160
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1161
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1162
1163
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
1164
      bin_mappers[i].reset(new BinMapper());
1165
1166
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
Guolin Ke's avatar
Guolin Ke committed
1167
1168
    }
  }
1169
  dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Guolin Ke's avatar
Guolin Ke committed
1170
                     Common::Vector2Ptr<double>(&sample_values).data(),
1171
                     Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
1172
  if (dataset->has_raw()) {
1173
    dataset->ResizeRaw(static_cast<int>(sample_data.size()));
1174
  }
1175
1176
1177
1178

  auto t2 = std::chrono::high_resolution_clock::now();
  Log::Info("Construct bin mappers from text data time %.2f seconds",
            std::chrono::duration<double, std::milli>(t2 - t1) * 1e-3);
Guolin Ke's avatar
Guolin Ke committed
1179
1180
1181
}

/*! \brief Extract local features from memory */
Guolin Ke's avatar
Guolin Ke committed
1182
void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
1183
1184
  std::vector<std::pair<int, double>> oneline_features;
  double tmp_label = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
1185
  auto& ref_text_data = *text_data;
1186
  std::vector<float> feature_row(dataset->num_features_);
1187
  if (!predict_fun_) {
1188
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1189
    // if doesn't need to prediction with initial model
1190
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
Guolin Ke's avatar
Guolin Ke committed
1191
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1192
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1193
1194
1195
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1196
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1197
      // set label
1198
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1199
      // free processed line:
Guolin Ke's avatar
Guolin Ke committed
1200
      ref_text_data[i].clear();
Guolin Ke's avatar
Guolin Ke committed
1201
1202
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
1203
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1204
1205
      // push data
      for (auto& inner_data : oneline_features) {
1206
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1207
1208
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1209
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1210
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1211
1212
1213
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
1214
          if (dataset->has_raw()) {
1215
            feature_row[feature_idx] = static_cast<float>(inner_data.second);
1216
          }
Guolin Ke's avatar
Guolin Ke committed
1217
1218
        } else {
          if (inner_data.first == weight_idx_) {
1219
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1220
1221
1222
1223
1224
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1225
1226
1227
1228
1229
1230
1231
1232
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1233
      dataset->FinishOneRow(tid, i, is_feature_added);
1234
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1235
    }
1236
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1237
  } else {
1238
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1239
    // if need to prediction with initial model
1240
    std::vector<double> init_score(dataset->num_data_ * num_class_);
1241
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
Guolin Ke's avatar
Guolin Ke committed
1242
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1243
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1244
1245
1246
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1247
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1248
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1249
1250
      std::vector<double> oneline_init_score(num_class_);
      predict_fun_(oneline_features, oneline_init_score.data());
1251
      for (int k = 0; k < num_class_; ++k) {
1252
        init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1253
1254
      }
      // set label
1255
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1256
      // free processed line:
1257
      ref_text_data[i].clear();
Andrew Ziem's avatar
Andrew Ziem committed
1258
      // shrink_to_fit will be very slow in Linux, and seems not free memory, disable for now
Guolin Ke's avatar
Guolin Ke committed
1259
1260
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
Guolin Ke's avatar
Guolin Ke committed
1261
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1262
      for (auto& inner_data : oneline_features) {
1263
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1264
1265
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1266
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1267
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1268
1269
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
1270
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
1271
          if (dataset->has_raw()) {
1272
            feature_row[feature_idx] = static_cast<float>(inner_data.second);
1273
          }
Guolin Ke's avatar
Guolin Ke committed
1274
1275
        } else {
          if (inner_data.first == weight_idx_) {
1276
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1277
1278
1279
1280
1281
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1282
      dataset->FinishOneRow(tid, i, is_feature_added);
1283
1284
1285
1286
1287
1288
1289
1290
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
1291
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1292
    }
1293
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1294
    // metadata_ will manage space of init_score
1295
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1296
  }
Guolin Ke's avatar
Guolin Ke committed
1297
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1298
  // text data can be free after loaded feature values
Guolin Ke's avatar
Guolin Ke committed
1299
  text_data->clear();
Guolin Ke's avatar
Guolin Ke committed
1300
1301
1302
}

/*! \brief Extract local features from file */
1303
1304
void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser,
                                            const std::vector<data_size_t>& used_data_indices, Dataset* dataset) {
1305
  std::vector<double> init_score;
1306
  if (predict_fun_) {
1307
    init_score = std::vector<double>(dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1308
1309
1310
1311
1312
1313
  }
  std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
    [this, &init_score, &parser, &dataset]
  (data_size_t start_idx, const std::vector<std::string>& lines) {
    std::vector<std::pair<int, double>> oneline_features;
    double tmp_label = 0.0f;
1314
    std::vector<float> feature_row(dataset->num_features_);
1315
    OMP_INIT_EX();
1316
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
Guolin Ke's avatar
Guolin Ke committed
1317
    for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
1318
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1319
1320
1321
1322
1323
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1324
      if (!init_score.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1325
1326
        std::vector<double> oneline_init_score(num_class_);
        predict_fun_(oneline_features, oneline_init_score.data());
1327
        for (int k = 0; k < num_class_; ++k) {
1328
          init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1329
1330
1331
        }
      }
      // set label
1332
      dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1333
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1334
1335
      // push data
      for (auto& inner_data : oneline_features) {
1336
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1337
1338
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1339
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1340
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1341
1342
1343
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
1344
          if (dataset->has_raw()) {
1345
            feature_row[feature_idx] = static_cast<float>(inner_data.second);
1346
          }
Guolin Ke's avatar
Guolin Ke committed
1347
1348
        } else {
          if (inner_data.first == weight_idx_) {
1349
            dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1350
1351
1352
1353
1354
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1355
1356
1357
1358
1359
1360
1361
1362
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1363
      dataset->FinishOneRow(tid, i, is_feature_added);
1364
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1365
    }
1366
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1367
  };
1368
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
1369
  if (!used_data_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1370
1371
1372
1373
1374
1375
1376
1377
    // only need part of data
    text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
  } else {
    // need full data
    text_reader.ReadAllAndProcessParallel(process_fun);
  }

  // metadata_ will manage space of init_score
Guolin Ke's avatar
Guolin Ke committed
1378
  if (!init_score.empty()) {
1379
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1380
  }
Guolin Ke's avatar
Guolin Ke committed
1381
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1382
1383
1384
}

/*! \brief Check can load from binary file */
1385
std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
1386
1387
1388
  std::string bin_filename(filename);
  bin_filename.append(".bin");

1389
  auto reader = VirtualFileReader::Make(bin_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
1390

1391
  if (!reader->Init()) {
1392
    bin_filename = std::string(filename);
1393
1394
    reader = VirtualFileReader::Make(bin_filename.c_str());
    if (!reader->Init()) {
1395
      Log::Fatal("Cannot open data file %s", bin_filename.c_str());
1396
    }
1397
  }
1398
1399
1400
1401
1402

  size_t buffer_size = 256;
  auto buffer = std::vector<char>(buffer_size);
  // read size of token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
1403
  size_t read_cnt = reader->Read(buffer.data(), size_of_token);
1404
1405
  if (read_cnt == size_of_token
      && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
1406
    return bin_filename;
Guolin Ke's avatar
Guolin Ke committed
1407
  } else {
1408
    return std::string();
Guolin Ke's avatar
Guolin Ke committed
1409
1410
1411
  }
}

1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
                                                              const std::unordered_set<int>& categorical_features) {
  std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
  if (forced_bins_path != "") {
    std::ifstream forced_bins_stream(forced_bins_path.c_str());
    if (forced_bins_stream.fail()) {
      Log::Warning("Could not open %s. Will ignore.", forced_bins_path.c_str());
    } else {
      std::stringstream buffer;
      buffer << forced_bins_stream.rdbuf();
      std::string err;
Guolin Ke's avatar
Guolin Ke committed
1423
      Json forced_bins_json = Json::parse(buffer.str(), &err);
1424
1425
1426
1427
      CHECK(forced_bins_json.is_array());
      std::vector<Json> forced_bins_arr = forced_bins_json.array_items();
      for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
        int feature_num = forced_bins_arr[i]["feature"].int_value();
Nikita Titov's avatar
Nikita Titov committed
1428
        CHECK_LT(feature_num, num_total_features);
1429
        if (categorical_features.count(feature_num)) {
1430
          Log::Warning("Feature %d is categorical. Will ignore forced bins for this feature.", feature_num);
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
        } else {
          std::vector<Json> bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items();
          for (size_t j = 0; j < bounds_arr.size(); ++j) {
            forced_bins[feature_num].push_back(bounds_arr[j].number_value());
          }
        }
      }
      // remove duplicates
      for (int i = 0; i < num_total_features; ++i) {
        auto new_end = std::unique(forced_bins[i].begin(), forced_bins[i].end());
        forced_bins[i].erase(new_end, forced_bins[i].end());
      }
    }
  }
  return forced_bins;
}

1448
}  // namespace LightGBM