dataset_loader.cpp 61.9 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/dataset_loader.h>

Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/network.h>
8
#include <LightGBM/utils/array_args.h>
9
#include <LightGBM/utils/json11.h>
10
11
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
12

13
#include <chrono>
14
15
#include <fstream>

Guolin Ke's avatar
Guolin Ke committed
16
17
namespace LightGBM {

18
19
using json11::Json;

Guolin Ke's avatar
Guolin Ke committed
20
21
/*!
 * \brief Create a loader bound to the given config and resolve the column settings.
 * \param io_config Config to copy; also decides whether raw feature values are kept
 * \param predict_fun Prediction function stored in the loader
 * \param num_class Number of classes stored in the loader
 * \param filename Data file handed to SetHeader() to resolve column settings, may be nullptr
 */
DatasetLoader::DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename)
  :config_(io_config), random_(config_.data_random_seed), predict_fun_(predict_fun), num_class_(num_class) {
  // defaults: label in the first column, no weight column, no group column
  group_idx_ = NO_SPECIFIC;
  weight_idx_ = NO_SPECIFIC;
  label_idx_ = 0;
  // may overwrite the defaults above according to the config / file header
  SetHeader(filename);
  // raw feature values are stored only when linear trees are enabled
  store_raw_ = io_config.linear_tree;
}

/*! \brief Destructor; there is nothing to release beyond what the members clean up themselves. */
DatasetLoader::~DatasetLoader() = default;

Guolin Ke's avatar
Guolin Ke committed
35
void DatasetLoader::SetHeader(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
36
  std::unordered_map<std::string, int> name2idx;
Guolin Ke's avatar
Guolin Ke committed
37
  std::string name_prefix("name:");
38
  if (filename != nullptr && CheckCanLoadFromBin(filename) == "") {
Guolin Ke's avatar
Guolin Ke committed
39
    TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
40

Guolin Ke's avatar
Guolin Ke committed
41
    // get column names
Guolin Ke's avatar
Guolin Ke committed
42
    if (config_.header) {
Guolin Ke's avatar
Guolin Ke committed
43
      std::string first_line = text_reader.first_line();
44
      feature_names_ = Common::Split(first_line.c_str(), "\t,");
45
46
47
48
49
50
51
52
53
54
55
56
    } else if (!config_.parser_config_file.empty()) {
      // support to get header from parser config, so could utilize following label name to id mapping logic.
      TextReader<data_size_t> parser_config_reader(config_.parser_config_file.c_str(), false);
      parser_config_reader.ReadAllLines();
      std::string parser_config_str = parser_config_reader.JoinedLines();
      if (!parser_config_str.empty()) {
        std::string header_in_parser_config = Common::GetFromParserConfig(parser_config_str, "header");
        if (!header_in_parser_config.empty()) {
          Log::Info("Get raw column names from parser config.");
          feature_names_ = Common::Split(header_in_parser_config.c_str(), "\t,");
        }
      }
Guolin Ke's avatar
Guolin Ke committed
57
58
    }

Guolin Ke's avatar
Guolin Ke committed
59
    // load label idx first
Guolin Ke's avatar
Guolin Ke committed
60
61
62
    if (config_.label_column.size() > 0) {
      if (Common::StartsWith(config_.label_column, name_prefix)) {
        std::string name = config_.label_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
63
64
65
66
67
68
69
70
71
72
        label_idx_ = -1;
        for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
          if (name == feature_names_[i]) {
            label_idx_ = i;
            break;
          }
        }
        if (label_idx_ >= 0) {
          Log::Info("Using column %s as label", name.c_str());
        } else {
73
74
          Log::Fatal("Could not find label column %s in data file \n"
                     "or data file doesn't contain header", name.c_str());
Guolin Ke's avatar
Guolin Ke committed
75
        }
Guolin Ke's avatar
Guolin Ke committed
76
      } else {
Guolin Ke's avatar
Guolin Ke committed
77
        if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
78
79
80
          Log::Fatal("label_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
81
82
        }
        Log::Info("Using column number %d as label", label_idx_);
Guolin Ke's avatar
Guolin Ke committed
83
84
      }
    }
Guolin Ke's avatar
Guolin Ke committed
85

86
87
88
89
90
91
92
93
94
    if (!config_.parser_config_file.empty()) {
      // if parser config file exists, feature names may be changed after customized parser applied.
      // clear here so could use default filled feature names during dataset construction.
      // may improve by saving real feature names defined in parser in the future.
      if (!feature_names_.empty()) {
        feature_names_.clear();
      }
    }

Guolin Ke's avatar
Guolin Ke committed
95
    if (!feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
96
97
98
99
      // erase label column name
      feature_names_.erase(feature_names_.begin() + label_idx_);
      for (size_t i = 0; i < feature_names_.size(); ++i) {
        name2idx[feature_names_[i]] = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
100
      }
Guolin Ke's avatar
Guolin Ke committed
101
102
103
    }

    // load ignore columns
Guolin Ke's avatar
Guolin Ke committed
104
105
106
    if (config_.ignore_column.size() > 0) {
      if (Common::StartsWith(config_.ignore_column, name_prefix)) {
        std::string names = config_.ignore_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
107
108
109
110
111
112
113
114
115
        for (auto name : Common::Split(names.c_str(), ',')) {
          if (name2idx.count(name) > 0) {
            int tmp = name2idx[name];
            ignore_features_.emplace(tmp);
          } else {
            Log::Fatal("Could not find ignore column %s in data file", name.c_str());
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
116
        for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
Guolin Ke's avatar
Guolin Ke committed
117
118
          int tmp = 0;
          if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
119
120
121
            Log::Fatal("ignore_column is not a number,\n"
                       "if you want to use a column name,\n"
                       "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
122
123
          }
          ignore_features_.emplace(tmp);
Guolin Ke's avatar
Guolin Ke committed
124
125
126
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
127
    // load weight idx
Guolin Ke's avatar
Guolin Ke committed
128
129
130
    if (config_.weight_column.size() > 0) {
      if (Common::StartsWith(config_.weight_column, name_prefix)) {
        std::string name = config_.weight_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
131
132
133
134
135
136
        if (name2idx.count(name) > 0) {
          weight_idx_ = name2idx[name];
          Log::Info("Using column %s as weight", name.c_str());
        } else {
          Log::Fatal("Could not find weight column %s in data file", name.c_str());
        }
Guolin Ke's avatar
Guolin Ke committed
137
      } else {
Guolin Ke's avatar
Guolin Ke committed
138
        if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
139
140
141
          Log::Fatal("weight_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
142
143
        }
        Log::Info("Using column number %d as weight", weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
144
      }
Guolin Ke's avatar
Guolin Ke committed
145
      ignore_features_.emplace(weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
146
    }
Guolin Ke's avatar
Guolin Ke committed
147
    // load group idx
Guolin Ke's avatar
Guolin Ke committed
148
149
150
    if (config_.group_column.size() > 0) {
      if (Common::StartsWith(config_.group_column, name_prefix)) {
        std::string name = config_.group_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
151
152
153
154
155
156
157
        if (name2idx.count(name) > 0) {
          group_idx_ = name2idx[name];
          Log::Info("Using column %s as group/query id", name.c_str());
        } else {
          Log::Fatal("Could not find group/query column %s in data file", name.c_str());
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
158
        if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
159
160
161
          Log::Fatal("group_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
162
163
164
165
        }
        Log::Info("Using column number %d as group/query id", group_idx_);
      }
      ignore_features_.emplace(group_idx_);
Guolin Ke's avatar
Guolin Ke committed
166
167
    }
  }
Guolin Ke's avatar
Guolin Ke committed
168
169
170
  if (config_.categorical_feature.size() > 0) {
    if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
      std::string names = config_.categorical_feature.substr(name_prefix.size());
171
172
173
174
175
      for (auto name : Common::Split(names.c_str(), ',')) {
        if (name2idx.count(name) > 0) {
          int tmp = name2idx[name];
          categorical_features_.emplace(tmp);
        } else {
Guolin Ke's avatar
Guolin Ke committed
176
          Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
177
178
179
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
180
      for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
181
182
        int tmp = 0;
        if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Guolin Ke's avatar
Guolin Ke committed
183
          Log::Fatal("categorical_feature is not a number,\n"
184
185
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
186
187
188
189
190
        }
        categorical_features_.emplace(tmp);
      }
    }
  }
Guolin Ke's avatar
Guolin Ke committed
191
192
}

193
194
195
196
197
198
199
200
201
202
/*!
 * \brief Warn when the sample used for bin construction looks too small.
 *
 * The warning is emitted only when both conditions hold: the sample covers less
 * than 20% of the data and it contains fewer than 100000 rows.
 *
 * \param sample_cnt Number of sampled rows
 * \param num_data Number of rows the sample was drawn from
 */
void CheckSampleSize(size_t sample_cnt, size_t num_data) {
  const bool low_ratio = static_cast<double>(sample_cnt) / num_data < 0.2f;
  const bool few_rows = sample_cnt < 100000;
  if (!low_ratio || !few_rows) {
    return;
  }
  Log::Warning(
      "Using too small ``bin_construct_sample_cnt`` may encounter unexpected errors and poor accuracy.");
}

203
/*!
 * \brief Load a dataset from a data file, constructing the feature bin mappers on the way.
 *
 * When a binary version of the file can be loaded it is used directly. Otherwise the text
 * file is parsed, either fully in memory or in two passes over the file when
 * `config_.two_round` is set. The loaded dataset is then validated with CheckDataset().
 *
 * \param filename Path of the data file
 * \param rank Rank of the current machine, forwarded to the loading / sampling helpers
 * \param num_machines Number of machines; more than one enables the distributed-training check
 * \return Newly created dataset; ownership is released to the caller
 */
Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_machines) {
  // don't support query id in data file when using distributed training
  if (num_machines > 1 && !config_.pre_partition) {
    // group_idx_ keeps its NO_SPECIFIC default unless a group column was configured;
    // index 0 is a valid group column as well, so compare against the sentinel
    if (group_idx_ != NO_SPECIFIC) {
      Log::Fatal("Using a query id without pre-partitioning the data file is not supported for distributed training.\n"
                 "Please use an additional query file or pre-partition the data");
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  if (store_raw_) {
    dataset->SetHasRaw(true);
  }
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto bin_filename = CheckCanLoadFromBin(filename);
  bool is_load_from_binary = false;
  if (bin_filename.size() == 0) {
    dataset->parser_config_str_ = Parser::GenerateParserConfigStr(filename, config_.parser_config_file.c_str(), config_.header, label_idx_);
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_,
                                                               config_.precise_float_parser, dataset->parser_config_str_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data to memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // sample data
      auto sample_data = SampleTextDataFromMemory(text_data);
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      // sample data from file
      auto sample_data = SampleTextDataFromFile(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      // a non-empty index list means only those rows are used locally
      if (used_data_indices.size() > 0) {
        dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
      } else {
        dataset->num_data_ = num_global_data;
      }
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      Log::Info("Making second pass...");
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    is_load_from_binary = true;
    Log::Info("Load from binary file %s", bin_filename.c_str());
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
  }
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  // need to check training data
  CheckDataset(dataset.get(), is_load_from_binary);

  return dataset.release();
}

282
/*!
 * \brief Load a dataset from a data file and align it with an existing dataset.
 *
 * No bin mappers are constructed here; for text input the new dataset is set up through
 * `CreateValid(train_data)`. A binary version of the file is used directly when it can be
 * loaded. Unlike LoadFromFile(), the result is not passed to CheckDataset().
 *
 * \param filename Path of the data file
 * \param train_data Dataset to align with; dereferenced on the text-file path
 * \return Newly created dataset; ownership is released to the caller
 */
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  if (store_raw_) {
    dataset->SetHasRaw(true);
  }
  auto bin_filename = CheckCanLoadFromBin(filename);
  if (bin_filename.size() == 0) {
    // use the parser config of the dataset we align with; the freshly created
    // dataset has not been given a parser config of its own at this point
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_,
                                                               config_.precise_float_parser, train_data->parser_config_str_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data in memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      dataset->CreateValid(train_data);
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      TextReader<data_size_t> text_reader(filename, config_.header);
      // Get number of lines of data file
      dataset->num_data_ = static_cast<data_size_t>(text_reader.CountLine());
      num_global_data = dataset->num_data_;
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      dataset->CreateValid(train_data);
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), 0, 1, &num_global_data, &used_data_indices));
  }
  // not need to check validation data
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  return dataset.release();
}

336
337
338
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
                                        int rank, int num_machines, int* num_global_data,
                                        std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
339
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
340
  auto reader = VirtualFileReader::Make(bin_filename);
Guolin Ke's avatar
Guolin Ke committed
341
  dataset->data_filename_ = data_filename;
342
  if (!reader->Init()) {
Guolin Ke's avatar
Guolin Ke committed
343
344
345
346
347
    Log::Fatal("Could not read binary data from %s", bin_filename);
  }

  // buffer to read binary file
  size_t buffer_size = 16 * 1024 * 1024;
Guolin Ke's avatar
Guolin Ke committed
348
  auto buffer = std::vector<char>(buffer_size);
349

350
351
  // check token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
352
353
354
355
  size_t read_cnt = reader->Read(
      buffer.data(),
      VirtualFileWriter::AlignedSize(sizeof(char) * size_of_token));
  if (read_cnt < sizeof(char) * size_of_token) {
356
357
358
    Log::Fatal("Binary file error: token has the wrong size");
  }
  if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
359
    Log::Fatal("Input file is not LightGBM binary file");
360
  }
Guolin Ke's avatar
Guolin Ke committed
361
362

  // read size of header
363
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
364

365
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
366
367
368
    Log::Fatal("Binary file error: header has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
369
  size_t size_of_head = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
370
371
372
373

  // re-allocate space if not enough
  if (size_of_head > buffer_size) {
    buffer_size = size_of_head;
Guolin Ke's avatar
Guolin Ke committed
374
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
375
376
  }
  // read header
377
  read_cnt = reader->Read(buffer.data(), size_of_head);
Guolin Ke's avatar
Guolin Ke committed
378
379
380
381
382

  if (read_cnt != size_of_head) {
    Log::Fatal("Binary file error: header is incorrect");
  }
  // get header
Guolin Ke's avatar
Guolin Ke committed
383
  const char* mem_ptr = buffer.data();
Guolin Ke's avatar
Guolin Ke committed
384
  dataset->num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
385
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_data_));
Guolin Ke's avatar
Guolin Ke committed
386
  dataset->num_features_ = *(reinterpret_cast<const int*>(mem_ptr));
387
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_features_));
Guolin Ke's avatar
Guolin Ke committed
388
  dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
389
390
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(dataset->num_total_features_));
Guolin Ke's avatar
Guolin Ke committed
391
  dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
392
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->label_idx_));
393
  dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
394
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->max_bin_));
395
  dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
396
397
  mem_ptr += VirtualFileWriter::AlignedSize(
      sizeof(dataset->bin_construct_sample_cnt_));
398
  dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
399
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->min_data_in_bin_));
400
  dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
401
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->use_missing_));
402
  dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
403
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->zero_as_missing_));
404
405
  dataset->has_raw_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->has_raw_));
Guolin Ke's avatar
Guolin Ke committed
406
407
  const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
  dataset->used_feature_map_.clear();
Guolin Ke's avatar
Guolin Ke committed
408
  for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
409
410
    dataset->used_feature_map_.push_back(tmp_feature_map[i]);
  }
411
412
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int) *
                                            dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
413
414
  // num_groups
  dataset->num_groups_ = *(reinterpret_cast<const int*>(mem_ptr));
415
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
416
417
418
419
420
421
  // real_feature_idx_
  const int* tmp_ptr_real_feature_idx_ = reinterpret_cast<const int*>(mem_ptr);
  dataset->real_feature_idx_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->real_feature_idx_.push_back(tmp_ptr_real_feature_idx_[i]);
  }
422
423
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
424
425
426
427
428
429
  // feature2group
  const int* tmp_ptr_feature2group = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2group_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2group_.push_back(tmp_ptr_feature2group[i]);
  }
430
431
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
432
433
434
435
436
437
  // feature2subfeature
  const int* tmp_ptr_feature2subfeature = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2subfeature_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2subfeature_.push_back(tmp_ptr_feature2subfeature[i]);
  }
438
439
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
440
441
442
443
444
445
446
447
448
449
450
  // group_bin_boundaries
  const uint64_t* tmp_ptr_group_bin_boundaries = reinterpret_cast<const uint64_t*>(mem_ptr);
  dataset->group_bin_boundaries_.clear();
  for (int i = 0; i < dataset->num_groups_ + 1; ++i) {
    dataset->group_bin_boundaries_.push_back(tmp_ptr_group_bin_boundaries[i]);
  }
  mem_ptr += sizeof(uint64_t) * (dataset->num_groups_ + 1);

  // group_feature_start_
  const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_start_.clear();
451
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
452
453
    dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
  }
454
455
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
456
457
458
459
460
461
462

  // group_feature_cnt_
  const int* tmp_ptr_group_feature_cnt = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_cnt_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_cnt_.push_back(tmp_ptr_group_feature_cnt[i]);
  }
463
464
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
465

Belinda Trotta's avatar
Belinda Trotta committed
466
  if (!config_.max_bin_by_feature.empty()) {
467
468
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
469
470
471
472
473
474
475
476
477
    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
  } else {
    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
    dataset->max_bin_by_feature_.clear();
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
    }
  }
478
479
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int32_t) *
                                            (dataset->num_total_features_));
Belinda Trotta's avatar
Belinda Trotta committed
480
481
482
483
  if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
    dataset->max_bin_by_feature_.clear();
  }

Guolin Ke's avatar
Guolin Ke committed
484
  // get feature names
Guolin Ke's avatar
Guolin Ke committed
485
  dataset->feature_names_.clear();
Guolin Ke's avatar
Guolin Ke committed
486
487
488
  // read feature names
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int str_len = *(reinterpret_cast<const int*>(mem_ptr));
489
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
Guolin Ke's avatar
Guolin Ke committed
490
    std::stringstream str_buf;
491
    auto tmp_arr = reinterpret_cast<const char*>(mem_ptr);
Guolin Ke's avatar
Guolin Ke committed
492
    for (int j = 0; j < str_len; ++j) {
493
      char tmp_char = tmp_arr[j];
Guolin Ke's avatar
Guolin Ke committed
494
495
      str_buf << tmp_char;
    }
496
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(char) * str_len);
Guolin Ke's avatar
Guolin Ke committed
497
    dataset->feature_names_.emplace_back(str_buf.str());
Guolin Ke's avatar
Guolin Ke committed
498
  }
499
500
501
502
  // get forced_bin_bounds_
  dataset->forced_bin_bounds_ = std::vector<std::vector<double>>(dataset->num_total_features_, std::vector<double>());
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int num_bounds = *(reinterpret_cast<const int*>(mem_ptr));
503
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
504
    dataset->forced_bin_bounds_[i] = std::vector<double>();
505
506
    const double* tmp_ptr_forced_bounds =
        reinterpret_cast<const double*>(mem_ptr);
507
508
509
510
511
512
    for (int j = 0; j < num_bounds; ++j) {
      double bound = tmp_ptr_forced_bounds[j];
      dataset->forced_bin_bounds_[i].push_back(bound);
    }
    mem_ptr += num_bounds * sizeof(double);
  }
Guolin Ke's avatar
Guolin Ke committed
513
514

  // read size of meta data
515
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
516

517
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
518
519
520
    Log::Fatal("Binary file error: meta data has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
521
  size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
522
523
524
525

  // re-allocate space if not enough
  if (size_of_metadata > buffer_size) {
    buffer_size = size_of_metadata;
Guolin Ke's avatar
Guolin Ke committed
526
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
527
528
  }
  //  read meta data
529
  read_cnt = reader->Read(buffer.data(), size_of_metadata);
Guolin Ke's avatar
Guolin Ke committed
530
531
532
533
534

  if (read_cnt != size_of_metadata) {
    Log::Fatal("Binary file error: meta data is incorrect");
  }
  // load meta data
Guolin Ke's avatar
Guolin Ke committed
535
  dataset->metadata_.LoadFromMemory(buffer.data());
Guolin Ke's avatar
Guolin Ke committed
536

537
538
  *num_global_data = dataset->num_data_;
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
539
  // sample local used data if need to partition
Guolin Ke's avatar
Guolin Ke committed
540
  if (num_machines > 1 && !config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
541
542
543
544
    const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
545
        if (random_.NextShort(0, num_machines) == rank) {
546
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
547
548
549
550
551
552
553
554
555
        }
      }
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = dataset->metadata_.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (qid >= num_queries) {
556
557
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
558
559
560
561
        }
        if (i >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
562
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
563
564
565
566
567
            is_query_used = true;
          }
          ++qid;
        }
        if (is_query_used) {
568
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
569
570
571
        }
      }
    }
572
    dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
Guolin Ke's avatar
Guolin Ke committed
573
  }
574
  dataset->metadata_.PartitionLabel(*used_data_indices);
Guolin Ke's avatar
Guolin Ke committed
575
  // read feature data
Guolin Ke's avatar
Guolin Ke committed
576
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
577
    // read feature size
578
579
    read_cnt = reader->Read(buffer.data(), sizeof(size_t));
    if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
580
581
      Log::Fatal("Binary file error: feature %d has the wrong size", i);
    }
Guolin Ke's avatar
Guolin Ke committed
582
    size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
583
584
585
    // re-allocate space if not enough
    if (size_of_feature > buffer_size) {
      buffer_size = size_of_feature;
Guolin Ke's avatar
Guolin Ke committed
586
      buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
587
588
    }

589
    read_cnt = reader->Read(buffer.data(), size_of_feature);
Guolin Ke's avatar
Guolin Ke committed
590
591
592
593

    if (read_cnt != size_of_feature) {
      Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
    }
Guolin Ke's avatar
Guolin Ke committed
594
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
595
596
      new FeatureGroup(buffer.data(),
                       *num_global_data,
597
                       *used_data_indices, i)));
Guolin Ke's avatar
Guolin Ke committed
598
  }
Guolin Ke's avatar
Guolin Ke committed
599
  dataset->feature_groups_.shrink_to_fit();
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635

  // raw data
  dataset->numeric_feature_map_ = std::vector<int>(dataset->num_features_, false);
  dataset->num_numeric_features_ = 0;
  for (int i = 0; i < dataset->num_features_; ++i) {
    if (dataset->FeatureBinMapper(i)->bin_type() == BinType::CategoricalBin) {
      dataset->numeric_feature_map_[i] = -1;
    } else {
      dataset->numeric_feature_map_[i] = dataset->num_numeric_features_;
      ++dataset->num_numeric_features_;
    }
  }
  if (dataset->has_raw()) {
    dataset->ResizeRaw(dataset->num_data());
      size_t row_size = dataset->num_numeric_features_ * sizeof(float);
      if (row_size > buffer_size) {
        buffer_size = row_size;
        buffer.resize(buffer_size);
      }
    for (int i = 0; i < dataset->num_data(); ++i) {
      read_cnt = reader->Read(buffer.data(), row_size);
      if (read_cnt != row_size) {
        Log::Fatal("Binary file error: row %d of raw data is incorrect, read count: %d", i, read_cnt);
      }
      mem_ptr = buffer.data();
      const float* tmp_ptr_raw_row = reinterpret_cast<const float*>(mem_ptr);
      for (int j = 0; j < dataset->num_features(); ++j) {
        int feat_ind = dataset->numeric_feature_map_[j];
        if (feat_ind >= 0) {
          dataset->raw_data_[feat_ind][i] = tmp_ptr_raw_row[feat_ind];
        }
      }
      mem_ptr += row_size;
    }
  }

Guolin Ke's avatar
Guolin Ke committed
636
  dataset->is_finish_load_ = true;
Guolin Ke's avatar
Guolin Ke committed
637
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
638
639
}

640

641
642
643
/*!
 * \brief Build a Dataset from column-wise sampled values: find one BinMapper per feature, then construct the dataset.
 *
 * With a single machine every column is binned locally. With several machines the
 * (globally synced) feature range is split into contiguous slices, each rank bins only
 * its own slice, and the serialized mappers are exchanged with Network::Allgather so
 * that every rank ends up with the full set.
 *
 * \param sample_values sample_values[c] is passed to BinMapper::FindBin as the sampled values of column c
 * \param sample_indices Per-column sample indices, forwarded unchanged to Dataset::Construct
 * \param num_col Number of columns available locally
 * \param num_per_col num_per_col[c] is the number of entries in sample_values[c]
 * \param total_sample_size Total number of sampled rows
 * \param num_data Number of rows the resulting dataset is created with
 * \return Newly allocated Dataset released from its unique_ptr; the caller receives the raw pointer
 */
Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
                                                int** sample_indices, int num_col, const int* num_per_col,
                                                size_t total_sample_size, data_size_t num_data) {
  CheckSampleSize(total_sample_size, static_cast<size_t>(num_data));
  int num_total_features = num_col;
  if (Network::num_machines() > 1) {
    num_total_features = Network::GlobalSyncUpByMax(num_total_features);
  }
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_total_features);
  // fill feature_names_ if not header
  if (feature_names_.empty()) {
    for (int i = 0; i < num_col; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
  if (!config_.max_bin_by_feature.empty()) {
    CHECK_EQ(static_cast<size_t>(num_col), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
  }

  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);

  // min_data_in_leaf scaled from the full data size down to the sample size
  const data_size_t filter_cnt = static_cast<data_size_t>(
    static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
  if (Network::num_machines() == 1) {
    // if only one machine, find bin locally
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
        bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
        if (!feat_is_unconstrained) {
            Log::Fatal("The output cannot be monotone with respect to categorical features");
        }
      }
      bin_mappers[i].reset(new BinMapper());
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                                config_.max_bin, config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter,
                                bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[i]);
      } else {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                                config_.max_bin_by_feature[i], config_.min_data_in_bin,
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
                                config_.zero_as_missing, forced_bin_bounds[i]);
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    // if have multi-machines, need to find bin distributed
    // different machines will find bin for different features
    int num_machines = Network::num_machines();
    int rank = Network::rank();
    // start and len will store the process feature indices for different machines
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
    // ceil(num_total_features / num_machines), but never less than one feature per slice
    int step = (num_total_features + num_machines - 1) / num_machines;
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
      len[i] = std::min(step, num_total_features - start[i]);
      start[i + 1] = start[i] + len[i];
    }
    len[num_machines - 1] = num_total_features - start[num_machines - 1];
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < len[rank]; ++i) {
      OMP_LOOP_EX_BEGIN();
      // i is the offset inside this rank's slice; feature_idx is the global feature index.
      // bin_mappers is temporarily indexed by the local offset, everything else by the global index.
      const int feature_idx = start[rank] + i;
      if (ignore_features_.count(feature_idx) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(feature_idx)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
      // this rank has no data for the feature: keep the default-constructed mapper
      if (num_col <= feature_idx) {
        continue;
      }
      // forced_bin_bounds is built per global feature, so it must be read with feature_idx
      // (reading it with the local offset i picks another feature's bounds when start[rank] > 0)
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[feature_idx], num_per_col[feature_idx],
                                total_sample_size, config_.max_bin, config_.min_data_in_bin,
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[feature_idx]);
      } else {
        bin_mappers[i]->FindBin(sample_values[feature_idx], num_per_col[feature_idx],
                                total_sample_size, config_.max_bin_by_feature[feature_idx],
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
                                config_.zero_as_missing, forced_bin_bounds[feature_idx]);
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
    // total serialized size of the mappers found by this rank
    comm_size_t self_buf_size = 0;
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
    }
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
      // free
      bin_mappers[i].reset(nullptr);
    }
    // per-rank buffer sizes and their offsets in the gathered buffer
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
    }
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
    // gather global feature bin mappers
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
    // restore features bins from buffer
    for (int i = 0; i < num_total_features; ++i) {
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      bin_mappers[i].reset(new BinMapper());
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
  dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
  if (dataset->has_raw()) {
    dataset->ResizeRaw(num_data);
  }
  dataset->set_feature_names(feature_names_);
  return dataset.release();
}
Guolin Ke's avatar
Guolin Ke committed
796
797
798
799


// ---- private functions ----

800
void DatasetLoader::CheckDataset(const Dataset* dataset, bool is_load_from_binary) {
Guolin Ke's avatar
Guolin Ke committed
801
  if (dataset->num_data_ <= 0) {
Guolin Ke's avatar
Guolin Ke committed
802
    Log::Fatal("Data file %s is empty", dataset->data_filename_.c_str());
Guolin Ke's avatar
Guolin Ke committed
803
  }
804
805
  if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
    Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
806
               static_cast<int>(dataset->feature_names_.size()));
807
  }
Guolin Ke's avatar
Guolin Ke committed
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
  bool is_feature_order_by_group = true;
  int last_group = -1;
  int last_sub_feature = -1;
  // if features are ordered, not need to use hist_buf
  for (int i = 0; i < dataset->num_features_; ++i) {
    int group = dataset->feature2group_[i];
    int sub_feature = dataset->feature2subfeature_[i];
    if (group < last_group) {
      is_feature_order_by_group = false;
    } else if (group == last_group) {
      if (sub_feature <= last_sub_feature) {
        is_feature_order_by_group = false;
        break;
      }
    }
    last_group = group;
    last_sub_feature = sub_feature;
  }
  if (!is_feature_order_by_group) {
827
    Log::Fatal("Features in dataset should be ordered by group");
Guolin Ke's avatar
Guolin Ke committed
828
  }
829
830
831

  if (is_load_from_binary) {
    if (dataset->max_bin_ != config_.max_bin) {
832
833
      Log::Fatal("Dataset was constructed with parameter max_bin=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->max_bin_, config_.max_bin);
834
835
    }
    if (dataset->min_data_in_bin_ != config_.min_data_in_bin) {
836
837
      Log::Fatal("Dataset was constructed with parameter min_data_in_bin=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->min_data_in_bin_, config_.min_data_in_bin);
838
839
    }
    if (dataset->use_missing_ != config_.use_missing) {
840
841
      Log::Fatal("Dataset was constructed with parameter use_missing=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->use_missing_, config_.use_missing);
842
843
    }
    if (dataset->zero_as_missing_ != config_.zero_as_missing) {
844
845
      Log::Fatal("Dataset was constructed with parameter zero_as_missing=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->zero_as_missing_, config_.zero_as_missing);
846
847
    }
    if (dataset->bin_construct_sample_cnt_ != config_.bin_construct_sample_cnt) {
848
849
      Log::Fatal("Dataset was constructed with parameter bin_construct_sample_cnt=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->bin_construct_sample_cnt_, config_.bin_construct_sample_cnt);
850
851
852
853
    }
    if ((dataset->max_bin_by_feature_.size() != config_.max_bin_by_feature.size()) ||
        !std::equal(dataset->max_bin_by_feature_.begin(), dataset->max_bin_by_feature_.end(),
            config_.max_bin_by_feature.begin())) {
854
      Log::Fatal("Parameter max_bin_by_feature cannot be changed when loading from binary file.");
855
856
    }

857
    if (config_.label_column != "") {
858
      Log::Warning("Parameter label_column works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
859
860
    }
    if (config_.weight_column != "") {
861
      Log::Warning("Parameter weight_column works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
862
863
    }
    if (config_.group_column != "") {
864
      Log::Warning("Parameter group_column works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
865
866
    }
    if (config_.ignore_column != "") {
867
      Log::Warning("Parameter ignore_column works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
868
    }
869
    if (config_.two_round) {
870
      Log::Warning("Parameter two_round works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
871
872
    }
    if (config_.header) {
873
      Log::Warning("Parameter header works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
874
    }
875
  }
Guolin Ke's avatar
Guolin Ke committed
876
877
878
}

std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
879
880
                                                             int rank, int num_machines, int* num_global_data,
                                                             std::vector<data_size_t>* used_data_indices) {
881
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
882
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
883
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
884
885
886
887
888
889
890
891
892
    // read all lines
    *num_global_data = text_reader.ReadAllLines();
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();

    if (query_boundaries == nullptr) {
      // if not contain query data, minimal sample unit is one record
      *num_global_data = text_reader.ReadAndFilterLines([this, rank, num_machines](data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
893
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
894
895
896
897
898
899
900
901
902
903
904
905
906
907
          return true;
        } else {
          return false;
        }
      }, used_data_indices);
    } else {
      // if contain query data, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.ReadAndFilterLines(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
908
909
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
910
911
912
913
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
914
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
915
916
917
918
919
920
921
922
923
924
925
926
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices);
    }
  }
  return std::move(text_reader.Lines());
}

std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
Guolin Ke's avatar
Guolin Ke committed
927
  int sample_cnt = config_.bin_construct_sample_cnt;
928
929
  if (static_cast<size_t>(sample_cnt) > data.size()) {
    sample_cnt = static_cast<int>(data.size());
930
  }
931
  auto sample_indices = random_.Sample(static_cast<int>(data.size()), sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
932
  std::vector<std::string> out(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
933
934
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    const size_t idx = sample_indices[i];
Guolin Ke's avatar
Guolin Ke committed
935
    out[i] = data[idx];
Guolin Ke's avatar
Guolin Ke committed
936
937
938
939
  }
  return out;
}

940
941
942
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata,
                                                               int rank, int num_machines, int* num_global_data,
                                                               std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
943
  const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
944
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
945
  std::vector<std::string> out_data;
Guolin Ke's avatar
Guolin Ke committed
946
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
947
    *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(&random_, sample_cnt, &out_data));
Guolin Ke's avatar
Guolin Ke committed
948
949
950
951
952
953
954
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      *num_global_data = text_reader.SampleAndFilterFromFile([this, rank, num_machines]
      (data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
955
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
956
957
958
959
          return true;
        } else {
          return false;
        }
Guolin Ke's avatar
Guolin Ke committed
960
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
961
962
963
964
965
966
967
968
969
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.SampleAndFilterFromFile(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
970
971
          Log::Fatal("Query id exceeds the range of the query file, "
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
972
973
974
975
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
976
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
977
978
979
980
981
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
Guolin Ke's avatar
Guolin Ke committed
982
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
983
984
985
986
987
    }
  }
  return out_data;
}

988
989
990
void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
                                                    const std::vector<std::string>& sample_data,
                                                    const Parser* parser, Dataset* dataset) {
991
  auto t1 = std::chrono::high_resolution_clock::now();
Guolin Ke's avatar
Guolin Ke committed
992
  std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
993
  std::vector<std::vector<int>> sample_indices;
Guolin Ke's avatar
Guolin Ke committed
994
995
  std::vector<std::pair<int, double>> oneline_features;
  double label;
Guolin Ke's avatar
Guolin Ke committed
996
  for (int i = 0; i < static_cast<int>(sample_data.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
997
998
999
1000
    oneline_features.clear();
    // parse features
    parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
    for (std::pair<int, double>& inner_data : oneline_features) {
1001
      if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
1002
1003
        sample_values.resize(inner_data.first + 1);
        sample_indices.resize(inner_data.first + 1);
1004
      }
Guolin Ke's avatar
Guolin Ke committed
1005
      if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
1006
1007
        sample_values[inner_data.first].emplace_back(inner_data.second);
        sample_indices[inner_data.first].emplace_back(i);
Guolin Ke's avatar
Guolin Ke committed
1008
1009
1010
1011
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
1012
  dataset->feature_groups_.clear();
1013
1014
1015
1016
1017
  dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
  if (num_machines > 1) {
    dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_);
  }
  if (!feature_names_.empty()) {
1018
    CHECK_EQ(dataset->num_total_features_, static_cast<int>(feature_names_.size()));
1019
  }
Guolin Ke's avatar
Guolin Ke committed
1020

Belinda Trotta's avatar
Belinda Trotta committed
1021
  if (!config_.max_bin_by_feature.empty()) {
1022
1023
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
1024
1025
  }

1026
1027
  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
1028
1029
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path,
                                                                                    dataset->num_total_features_,
1030
1031
                                                                                    categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
1032
  // check the range of label_idx, weight_idx and group_idx
1033
1034
1035
1036
1037
  // skip label check if user input parser config file,
  // because label id is got from raw features while dataset features are consistent with customized parser.
  if (dataset->parser_config_str_.empty()) {
    CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
  }
Guolin Ke's avatar
Guolin Ke committed
1038
1039
1040
1041
  CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
  CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);

  // fill feature_names_ if not header
Guolin Ke's avatar
Guolin Ke committed
1042
  if (feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1043
1044
1045
1046
1047
1048
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
1049
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
1050
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
1051
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
1052
    static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
Guolin Ke's avatar
Guolin Ke committed
1053
1054
1055
  // start find bins
  if (num_machines == 1) {
    // if only one machine, find bin locally
1056
    OMP_INIT_EX();
1057
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
1058
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
1059
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1060
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1061
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1062
1063
        continue;
      }
1064
1065
1066
1067
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
Guolin Ke's avatar
Guolin Ke committed
1068
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
1069
1070
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
1071
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
1072
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
1073
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1074
1075
      } else {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
1076
                                sample_data.size(), config_.max_bin_by_feature[i],
1077
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
1078
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1079
      }
1080
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1081
    }
1082
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1083
1084
  } else {
    // start and len will store the process feature indices for different machines
1085
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
Guolin Ke's avatar
Guolin Ke committed
1086
1087
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
1088
    int step = (dataset->num_total_features_ + num_machines - 1) / num_machines;
Guolin Ke's avatar
Guolin Ke committed
1089
1090
1091
1092
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
1093
      len[i] = std::min(step, dataset->num_total_features_ - start[i]);
Guolin Ke's avatar
Guolin Ke committed
1094
1095
      start[i + 1] = start[i] + len[i];
    }
1096
    len[num_machines - 1] = dataset->num_total_features_ - start[num_machines - 1];
1097
    OMP_INIT_EX();
1098
    #pragma omp parallel for schedule(guided)
1099
    for (int i = 0; i < len[rank]; ++i) {
1100
      OMP_LOOP_EX_BEGIN();
1101
1102
1103
1104
1105
1106
1107
1108
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
Nikita Titov's avatar
Nikita Titov committed
1109
      if (static_cast<int>(sample_values.size()) <= start[rank] + i) {
1110
1111
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
1112
      if (config_.max_bin_by_feature.empty()) {
1113
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
1114
                                static_cast<int>(sample_values[start[rank] + i].size()),
1115
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
1116
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
1117
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1118
      } else {
1119
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
1120
                                static_cast<int>(sample_values[start[rank] + i].size()),
1121
                                sample_data.size(), config_.max_bin_by_feature[i],
1122
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type,
1123
                                config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1124
      }
1125
      OMP_LOOP_EX_END();
1126
    }
1127
    OMP_THROW_EX();
1128
    comm_size_t self_buf_size = 0;
Guolin Ke's avatar
Guolin Ke committed
1129
    for (int i = 0; i < len[rank]; ++i) {
1130
1131
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
Guolin Ke's avatar
Guolin Ke committed
1132
      }
1133
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
1134
    }
1135
1136
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1137
    for (int i = 0; i < len[rank]; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1138
1139
1140
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
1141
1142
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
1143
1144
      // free
      bin_mappers[i].reset(nullptr);
Guolin Ke's avatar
Guolin Ke committed
1145
    }
1146
1147
1148
1149
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
Guolin Ke's avatar
Guolin Ke committed
1150
    }
1151
1152
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
Guolin Ke's avatar
Guolin Ke committed
1153
    // gather global feature bin mappers
1154
1155
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1156
    // restore features bins from buffer
1157
    for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1158
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1159
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1160
1161
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
1162
      bin_mappers[i].reset(new BinMapper());
1163
1164
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
Guolin Ke's avatar
Guolin Ke committed
1165
1166
    }
  }
1167
  dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Guolin Ke's avatar
Guolin Ke committed
1168
                     Common::Vector2Ptr<double>(&sample_values).data(),
1169
                     Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
1170
  if (dataset->has_raw()) {
1171
    dataset->ResizeRaw(static_cast<int>(sample_data.size()));
1172
  }
1173
1174
1175
1176

  auto t2 = std::chrono::high_resolution_clock::now();
  Log::Info("Construct bin mappers from text data time %.2f seconds",
            std::chrono::duration<double, std::milli>(t2 - t1) * 1e-3);
Guolin Ke's avatar
Guolin Ke committed
1177
1178
1179
}

/*! \brief Extract local features from memory */
Guolin Ke's avatar
Guolin Ke committed
1180
void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
1181
1182
  std::vector<std::pair<int, double>> oneline_features;
  double tmp_label = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
1183
  auto& ref_text_data = *text_data;
1184
  std::vector<float> feature_row(dataset->num_features_);
1185
  if (!predict_fun_) {
1186
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1187
    // if doesn't need to prediction with initial model
1188
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
Guolin Ke's avatar
Guolin Ke committed
1189
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1190
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1191
1192
1193
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1194
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1195
      // set label
1196
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1197
      // free processed line:
Guolin Ke's avatar
Guolin Ke committed
1198
      ref_text_data[i].clear();
Guolin Ke's avatar
Guolin Ke committed
1199
1200
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
1201
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1202
1203
      // push data
      for (auto& inner_data : oneline_features) {
1204
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1205
1206
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1207
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1208
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1209
1210
1211
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
1212
          if (dataset->has_raw()) {
1213
            feature_row[feature_idx] = static_cast<float>(inner_data.second);
1214
          }
Guolin Ke's avatar
Guolin Ke committed
1215
1216
        } else {
          if (inner_data.first == weight_idx_) {
1217
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1218
1219
1220
1221
1222
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1223
1224
1225
1226
1227
1228
1229
1230
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1231
      dataset->FinishOneRow(tid, i, is_feature_added);
1232
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1233
    }
1234
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1235
  } else {
1236
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1237
    // if need to prediction with initial model
1238
    std::vector<double> init_score(dataset->num_data_ * num_class_);
1239
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
Guolin Ke's avatar
Guolin Ke committed
1240
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1241
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1242
1243
1244
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1245
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1246
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1247
1248
      std::vector<double> oneline_init_score(num_class_);
      predict_fun_(oneline_features, oneline_init_score.data());
1249
      for (int k = 0; k < num_class_; ++k) {
1250
        init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1251
1252
      }
      // set label
1253
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1254
      // free processed line:
1255
      ref_text_data[i].clear();
Andrew Ziem's avatar
Andrew Ziem committed
1256
      // shrink_to_fit will be very slow in Linux, and seems not free memory, disable for now
Guolin Ke's avatar
Guolin Ke committed
1257
1258
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
Guolin Ke's avatar
Guolin Ke committed
1259
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1260
      for (auto& inner_data : oneline_features) {
1261
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1262
1263
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1264
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1265
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1266
1267
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
1268
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
1269
          if (dataset->has_raw()) {
1270
            feature_row[feature_idx] = static_cast<float>(inner_data.second);
1271
          }
Guolin Ke's avatar
Guolin Ke committed
1272
1273
        } else {
          if (inner_data.first == weight_idx_) {
1274
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1275
1276
1277
1278
1279
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1280
      dataset->FinishOneRow(tid, i, is_feature_added);
1281
1282
1283
1284
1285
1286
1287
1288
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
1289
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1290
    }
1291
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1292
    // metadata_ will manage space of init_score
1293
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1294
  }
Guolin Ke's avatar
Guolin Ke committed
1295
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1296
  // text data can be free after loaded feature values
Guolin Ke's avatar
Guolin Ke committed
1297
  text_data->clear();
Guolin Ke's avatar
Guolin Ke committed
1298
1299
1300
}

/*! \brief Extract local features from file */
1301
1302
void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser,
                                            const std::vector<data_size_t>& used_data_indices, Dataset* dataset) {
1303
  std::vector<double> init_score;
1304
  if (predict_fun_) {
1305
    init_score = std::vector<double>(dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1306
1307
1308
1309
1310
1311
  }
  std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
    [this, &init_score, &parser, &dataset]
  (data_size_t start_idx, const std::vector<std::string>& lines) {
    std::vector<std::pair<int, double>> oneline_features;
    double tmp_label = 0.0f;
1312
    std::vector<float> feature_row(dataset->num_features_);
1313
    OMP_INIT_EX();
1314
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
Guolin Ke's avatar
Guolin Ke committed
1315
    for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
1316
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1317
1318
1319
1320
1321
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1322
      if (!init_score.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1323
1324
        std::vector<double> oneline_init_score(num_class_);
        predict_fun_(oneline_features, oneline_init_score.data());
1325
        for (int k = 0; k < num_class_; ++k) {
1326
          init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1327
1328
1329
        }
      }
      // set label
1330
      dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1331
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1332
1333
      // push data
      for (auto& inner_data : oneline_features) {
1334
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1335
1336
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1337
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1338
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1339
1340
1341
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
1342
          if (dataset->has_raw()) {
1343
            feature_row[feature_idx] = static_cast<float>(inner_data.second);
1344
          }
Guolin Ke's avatar
Guolin Ke committed
1345
1346
        } else {
          if (inner_data.first == weight_idx_) {
1347
            dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1348
1349
1350
1351
1352
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1353
1354
1355
1356
1357
1358
1359
1360
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1361
      dataset->FinishOneRow(tid, i, is_feature_added);
1362
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1363
    }
1364
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1365
  };
1366
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
1367
  if (!used_data_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1368
1369
1370
1371
1372
1373
1374
1375
    // only need part of data
    text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
  } else {
    // need full data
    text_reader.ReadAllAndProcessParallel(process_fun);
  }

  // metadata_ will manage space of init_score
Guolin Ke's avatar
Guolin Ke committed
1376
  if (!init_score.empty()) {
1377
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1378
  }
Guolin Ke's avatar
Guolin Ke committed
1379
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1380
1381
1382
}

/*! \brief Check can load from binary file */
1383
std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
1384
1385
1386
  std::string bin_filename(filename);
  bin_filename.append(".bin");

1387
  auto reader = VirtualFileReader::Make(bin_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
1388

1389
  if (!reader->Init()) {
1390
    bin_filename = std::string(filename);
1391
1392
    reader = VirtualFileReader::Make(bin_filename.c_str());
    if (!reader->Init()) {
1393
      Log::Fatal("Cannot open data file %s", bin_filename.c_str());
1394
    }
1395
  }
1396
1397
1398
1399
1400

  size_t buffer_size = 256;
  auto buffer = std::vector<char>(buffer_size);
  // read size of token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
1401
  size_t read_cnt = reader->Read(buffer.data(), size_of_token);
1402
1403
  if (read_cnt == size_of_token
      && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
1404
    return bin_filename;
Guolin Ke's avatar
Guolin Ke committed
1405
  } else {
1406
    return std::string();
Guolin Ke's avatar
Guolin Ke committed
1407
1408
1409
  }
}

1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
                                                              const std::unordered_set<int>& categorical_features) {
  std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
  if (forced_bins_path != "") {
    std::ifstream forced_bins_stream(forced_bins_path.c_str());
    if (forced_bins_stream.fail()) {
      Log::Warning("Could not open %s. Will ignore.", forced_bins_path.c_str());
    } else {
      std::stringstream buffer;
      buffer << forced_bins_stream.rdbuf();
      std::string err;
Guolin Ke's avatar
Guolin Ke committed
1421
      Json forced_bins_json = Json::parse(buffer.str(), &err);
1422
1423
1424
1425
      CHECK(forced_bins_json.is_array());
      std::vector<Json> forced_bins_arr = forced_bins_json.array_items();
      for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
        int feature_num = forced_bins_arr[i]["feature"].int_value();
Nikita Titov's avatar
Nikita Titov committed
1426
        CHECK_LT(feature_num, num_total_features);
1427
        if (categorical_features.count(feature_num)) {
1428
          Log::Warning("Feature %d is categorical. Will ignore forced bins for this feature.", feature_num);
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
        } else {
          std::vector<Json> bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items();
          for (size_t j = 0; j < bounds_arr.size(); ++j) {
            forced_bins[feature_num].push_back(bounds_arr[j].number_value());
          }
        }
      }
      // remove duplicates
      for (int i = 0; i < num_total_features; ++i) {
        auto new_end = std::unique(forced_bins[i].begin(), forced_bins[i].end());
        forced_bins[i].erase(new_end, forced_bins[i].end());
      }
    }
  }
  return forced_bins;
}

1446
}  // namespace LightGBM