dataset_loader.cpp 62.4 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/dataset_loader.h>

Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/network.h>
8
#include <LightGBM/utils/array_args.h>
9
#include <LightGBM/utils/json11.h>
10
11
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
12

13
#include <chrono>
14
15
#include <fstream>

Guolin Ke's avatar
Guolin Ke committed
16
17
namespace LightGBM {

18
19
using json11::Json;

Guolin Ke's avatar
Guolin Ke committed
20
21
/*!
 * \brief Create a loader and resolve the column roles for the given data file.
 * \param io_config Configuration used to initialize config_; its linear_tree flag decides store_raw_
 * \param predict_fun Prediction callback stored in predict_fun_
 * \param num_class Number of classes stored in num_class_
 * \param filename Data file handed to SetHeader(); may be nullptr
 */
DatasetLoader::DatasetLoader(const Config& io_config,
                             const PredictFunction& predict_fun,
                             int num_class,
                             const char* filename)
    : config_(io_config),
      random_(config_.data_random_seed),
      predict_fun_(predict_fun),
      num_class_(num_class) {
  // defaults for the column roles; they have to be in place before SetHeader(),
  // which overrides them according to the configuration
  label_idx_ = 0;
  weight_idx_ = NO_SPECIFIC;
  group_idx_ = NO_SPECIFIC;
  SetHeader(filename);

  // store_raw_ simply mirrors the linear_tree setting
  store_raw_ = static_cast<bool>(io_config.linear_tree);
}

/*! \brief Destructor; there is nothing to do beyond destroying the members. */
DatasetLoader::~DatasetLoader() = default;

Guolin Ke's avatar
Guolin Ke committed
35
void DatasetLoader::SetHeader(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
36
  std::unordered_map<std::string, int> name2idx;
Guolin Ke's avatar
Guolin Ke committed
37
  std::string name_prefix("name:");
38
  if (filename != nullptr && CheckCanLoadFromBin(filename) == "") {
Guolin Ke's avatar
Guolin Ke committed
39
    TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
40

Guolin Ke's avatar
Guolin Ke committed
41
    // get column names
Guolin Ke's avatar
Guolin Ke committed
42
    if (config_.header) {
Guolin Ke's avatar
Guolin Ke committed
43
      std::string first_line = text_reader.first_line();
44
      feature_names_ = Common::Split(first_line.c_str(), "\t,");
45
46
47
48
49
50
51
52
53
54
55
56
    } else if (!config_.parser_config_file.empty()) {
      // support to get header from parser config, so could utilize following label name to id mapping logic.
      TextReader<data_size_t> parser_config_reader(config_.parser_config_file.c_str(), false);
      parser_config_reader.ReadAllLines();
      std::string parser_config_str = parser_config_reader.JoinedLines();
      if (!parser_config_str.empty()) {
        std::string header_in_parser_config = Common::GetFromParserConfig(parser_config_str, "header");
        if (!header_in_parser_config.empty()) {
          Log::Info("Get raw column names from parser config.");
          feature_names_ = Common::Split(header_in_parser_config.c_str(), "\t,");
        }
      }
Guolin Ke's avatar
Guolin Ke committed
57
58
    }

Guolin Ke's avatar
Guolin Ke committed
59
    // load label idx first
Guolin Ke's avatar
Guolin Ke committed
60
61
62
    if (config_.label_column.size() > 0) {
      if (Common::StartsWith(config_.label_column, name_prefix)) {
        std::string name = config_.label_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
63
64
65
66
67
68
69
70
71
72
        label_idx_ = -1;
        for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
          if (name == feature_names_[i]) {
            label_idx_ = i;
            break;
          }
        }
        if (label_idx_ >= 0) {
          Log::Info("Using column %s as label", name.c_str());
        } else {
73
74
          Log::Fatal("Could not find label column %s in data file \n"
                     "or data file doesn't contain header", name.c_str());
Guolin Ke's avatar
Guolin Ke committed
75
        }
Guolin Ke's avatar
Guolin Ke committed
76
      } else {
Guolin Ke's avatar
Guolin Ke committed
77
        if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
78
79
80
          Log::Fatal("label_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
81
82
        }
        Log::Info("Using column number %d as label", label_idx_);
Guolin Ke's avatar
Guolin Ke committed
83
84
      }
    }
Guolin Ke's avatar
Guolin Ke committed
85

86
87
88
89
90
91
92
93
94
    if (!config_.parser_config_file.empty()) {
      // if parser config file exists, feature names may be changed after customized parser applied.
      // clear here so could use default filled feature names during dataset construction.
      // may improve by saving real feature names defined in parser in the future.
      if (!feature_names_.empty()) {
        feature_names_.clear();
      }
    }

Guolin Ke's avatar
Guolin Ke committed
95
    if (!feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
96
97
98
99
      // erase label column name
      feature_names_.erase(feature_names_.begin() + label_idx_);
      for (size_t i = 0; i < feature_names_.size(); ++i) {
        name2idx[feature_names_[i]] = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
100
      }
Guolin Ke's avatar
Guolin Ke committed
101
102
103
    }

    // load ignore columns
Guolin Ke's avatar
Guolin Ke committed
104
105
106
    if (config_.ignore_column.size() > 0) {
      if (Common::StartsWith(config_.ignore_column, name_prefix)) {
        std::string names = config_.ignore_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
107
108
109
110
111
112
113
114
115
        for (auto name : Common::Split(names.c_str(), ',')) {
          if (name2idx.count(name) > 0) {
            int tmp = name2idx[name];
            ignore_features_.emplace(tmp);
          } else {
            Log::Fatal("Could not find ignore column %s in data file", name.c_str());
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
116
        for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
Guolin Ke's avatar
Guolin Ke committed
117
118
          int tmp = 0;
          if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
119
120
121
            Log::Fatal("ignore_column is not a number,\n"
                       "if you want to use a column name,\n"
                       "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
122
123
          }
          ignore_features_.emplace(tmp);
Guolin Ke's avatar
Guolin Ke committed
124
125
126
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
127
    // load weight idx
Guolin Ke's avatar
Guolin Ke committed
128
129
130
    if (config_.weight_column.size() > 0) {
      if (Common::StartsWith(config_.weight_column, name_prefix)) {
        std::string name = config_.weight_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
131
132
133
134
135
136
        if (name2idx.count(name) > 0) {
          weight_idx_ = name2idx[name];
          Log::Info("Using column %s as weight", name.c_str());
        } else {
          Log::Fatal("Could not find weight column %s in data file", name.c_str());
        }
Guolin Ke's avatar
Guolin Ke committed
137
      } else {
Guolin Ke's avatar
Guolin Ke committed
138
        if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
139
140
141
          Log::Fatal("weight_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
142
143
        }
        Log::Info("Using column number %d as weight", weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
144
      }
Guolin Ke's avatar
Guolin Ke committed
145
      ignore_features_.emplace(weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
146
    }
Guolin Ke's avatar
Guolin Ke committed
147
    // load group idx
Guolin Ke's avatar
Guolin Ke committed
148
149
150
    if (config_.group_column.size() > 0) {
      if (Common::StartsWith(config_.group_column, name_prefix)) {
        std::string name = config_.group_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
151
152
153
154
155
156
157
        if (name2idx.count(name) > 0) {
          group_idx_ = name2idx[name];
          Log::Info("Using column %s as group/query id", name.c_str());
        } else {
          Log::Fatal("Could not find group/query column %s in data file", name.c_str());
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
158
        if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
159
160
161
          Log::Fatal("group_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
162
163
164
165
        }
        Log::Info("Using column number %d as group/query id", group_idx_);
      }
      ignore_features_.emplace(group_idx_);
Guolin Ke's avatar
Guolin Ke committed
166
167
    }
  }
Guolin Ke's avatar
Guolin Ke committed
168
169
170
  if (config_.categorical_feature.size() > 0) {
    if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
      std::string names = config_.categorical_feature.substr(name_prefix.size());
171
172
173
174
175
      for (auto name : Common::Split(names.c_str(), ',')) {
        if (name2idx.count(name) > 0) {
          int tmp = name2idx[name];
          categorical_features_.emplace(tmp);
        } else {
Guolin Ke's avatar
Guolin Ke committed
176
          Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
177
178
179
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
180
      for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
181
182
        int tmp = 0;
        if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Guolin Ke's avatar
Guolin Ke committed
183
          Log::Fatal("categorical_feature is not a number,\n"
184
185
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
186
187
188
189
190
        }
        categorical_features_.emplace(tmp);
      }
    }
  }
Guolin Ke's avatar
Guolin Ke committed
191
192
}

193
194
195
196
197
198
199
200
201
202
/*!
 * \brief Warn when the sample used for bin construction looks too small.
 *
 * The warning is emitted only when both conditions hold: the sample covers
 * less than 20% of the data and it has fewer than 100000 rows.
 *
 * \param sample_cnt Number of sampled rows
 * \param num_data Number of rows the sample was drawn from
 */
void CheckSampleSize(size_t sample_cnt, size_t num_data) {
  const double sampled_fraction = static_cast<double>(sample_cnt) / num_data;
  const bool fraction_is_small = sampled_fraction < 0.2f;
  const bool count_is_small = sample_cnt < 100000;
  if (!(fraction_is_small && count_is_small)) {
    return;
  }
  Log::Warning(
      "Using too small ``bin_construct_sample_cnt`` may encounter "
      "unexpected "
      "errors and poor accuracy.");
}

203
/*!
 * \brief Load a dataset from a text data file, or from its binary file when CheckCanLoadFromBin() finds one.
 *
 * For text input the bin mappers are constructed from a sample of the rows and the features are then
 * extracted, either from rows held in memory or, with config_.two_round, in a second pass over the file.
 * The metadata is checked or partitioned afterwards and the result is passed to CheckDataset().
 *
 * \param filename Path of the data file
 * \param rank Rank of this machine, forwarded to the loading/sampling helpers
 * \param num_machines Total number of machines
 * \return Newly created dataset, released from its unique_ptr; the caller takes ownership
 */
Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_machines) {
  // don't support query id in data file when using distributed training
  // compare against the NO_SPECIFIC sentinel so that a group column resolved to index 0 is rejected as well
  if (num_machines > 1 && !config_.pre_partition && group_idx_ != NO_SPECIFIC) {
    Log::Fatal("Using a query id without pre-partitioning the data file is not supported for distributed training.\n"
               "Please use an additional query file or pre-partition the data");
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  if (store_raw_) {
    dataset->SetHasRaw(true);
  }
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto bin_filename = CheckCanLoadFromBin(filename);
  bool is_load_from_binary = false;
  if (bin_filename.size() == 0) {
    dataset->parser_config_str_ = Parser::GenerateParserConfigStr(filename, config_.parser_config_file.c_str(), config_.header, label_idx_);
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_,
                                                               config_.precise_float_parser, dataset->parser_config_str_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data to memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // sample data
      auto sample_data = SampleTextDataFromMemory(text_data);
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers & clear sample data
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      std::vector<std::string>().swap(sample_data);
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      // sample data from file
      auto sample_data = SampleTextDataFromFile(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      // when a subset of rows was selected, only those rows are counted
      if (used_data_indices.size() > 0) {
        dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
      } else {
        dataset->num_data_ = num_global_data;
      }
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers & clear sample data
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      std::vector<std::string>().swap(sample_data);
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      Log::Info("Making second pass...");
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    is_load_from_binary = true;
    Log::Info("Load from binary file %s", bin_filename.c_str());
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
    dataset->device_type_ = config_.device_type;
    dataset->gpu_device_id_ = config_.gpu_device_id;
    #ifdef USE_CUDA_EXP
    if (config_.device_type == std::string("cuda_exp")) {
      dataset->CreateCUDAColumnData();
      dataset->metadata_.CreateCUDAMetadata(dataset->gpu_device_id_);
    } else {
      dataset->cuda_column_data_ = nullptr;
    }
    #endif  // USE_CUDA_EXP
  }
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  // need to check training data
  CheckDataset(dataset.get(), is_load_from_binary);

  return dataset.release();
}

294
/*!
 * \brief Load a data file into a dataset that is created as a "valid" counterpart of train_data.
 *
 * A binary file reported by CheckCanLoadFromBin() is loaded directly. Text input is parsed with the
 * parser configuration stored in train_data and laid out through Dataset::CreateValid(train_data).
 * Loading always runs as rank 0 of 1 machine, and the result is not passed to CheckDataset().
 *
 * \param filename Path of the data file
 * \param train_data Dataset the new dataset is aligned with
 * \return Newly created dataset, released from its unique_ptr; the caller takes ownership
 */
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
  data_size_t total_rows = 0;
  std::vector<data_size_t> kept_rows;
  std::unique_ptr<Dataset> aligned(new Dataset());
  if (store_raw_) {
    aligned->SetHasRaw(true);
  }
  auto bin_path = CheckCanLoadFromBin(filename);
  if (bin_path.size() != 0) {
    // load data from binary file
    aligned.reset(LoadFromBinFile(filename, bin_path.c_str(), 0, 1, &total_rows, &kept_rows));
  } else {
    std::unique_ptr<Parser> row_parser(Parser::CreateParser(filename, config_.header, 0, label_idx_,
                                                            config_.precise_float_parser, train_data->parser_config_str_));
    if (row_parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    aligned->data_filename_ = filename;
    aligned->label_idx_ = label_idx_;
    aligned->metadata_.Init(filename);
    if (config_.two_round) {
      TextReader<data_size_t> line_counter(filename, config_.header);
      // the number of rows is the number of lines of the data file
      aligned->num_data_ = static_cast<data_size_t>(line_counter.CountLine());
      total_rows = aligned->num_data_;
      // set up the label storage, then take over the layout of train_data
      aligned->metadata_.Init(aligned->num_data_, weight_idx_, group_idx_);
      aligned->CreateValid(train_data);
      if (aligned->has_raw()) {
        aligned->ResizeRaw(aligned->num_data_);
      }
      // features are extracted straight from the file
      ExtractFeaturesFromFile(filename, row_parser.get(), kept_rows, aligned.get());
    } else {
      // keep all rows in memory
      auto rows = LoadTextDataToMemory(filename, aligned->metadata_, 0, 1, &total_rows, &kept_rows);
      aligned->num_data_ = static_cast<data_size_t>(rows.size());
      // set up the label storage, then take over the layout of train_data
      aligned->metadata_.Init(aligned->num_data_, weight_idx_, group_idx_);
      aligned->CreateValid(train_data);
      if (aligned->has_raw()) {
        aligned->ResizeRaw(aligned->num_data_);
      }
      // features are extracted from the rows held in memory
      ExtractFeaturesFromMemory(&rows, row_parser.get(), aligned.get());
      rows.clear();
    }
  }
  // validation data is not passed through CheckDataset, only the meta data is checked
  aligned->metadata_.CheckOrPartition(total_rows, kept_rows);
  return aligned.release();
}

348
349
350
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
                                        int rank, int num_machines, int* num_global_data,
                                        std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
351
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
352
  auto reader = VirtualFileReader::Make(bin_filename);
Guolin Ke's avatar
Guolin Ke committed
353
  dataset->data_filename_ = data_filename;
354
  if (!reader->Init()) {
Guolin Ke's avatar
Guolin Ke committed
355
356
357
358
359
    Log::Fatal("Could not read binary data from %s", bin_filename);
  }

  // buffer to read binary file
  size_t buffer_size = 16 * 1024 * 1024;
Guolin Ke's avatar
Guolin Ke committed
360
  auto buffer = std::vector<char>(buffer_size);
361

362
363
  // check token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
364
365
366
367
  size_t read_cnt = reader->Read(
      buffer.data(),
      VirtualFileWriter::AlignedSize(sizeof(char) * size_of_token));
  if (read_cnt < sizeof(char) * size_of_token) {
368
369
370
    Log::Fatal("Binary file error: token has the wrong size");
  }
  if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
371
    Log::Fatal("Input file is not LightGBM binary file");
372
  }
Guolin Ke's avatar
Guolin Ke committed
373
374

  // read size of header
375
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
376

377
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
378
379
380
    Log::Fatal("Binary file error: header has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
381
  size_t size_of_head = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
382
383
384
385

  // re-allocate space if not enough
  if (size_of_head > buffer_size) {
    buffer_size = size_of_head;
Guolin Ke's avatar
Guolin Ke committed
386
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
387
388
  }
  // read header
389
  read_cnt = reader->Read(buffer.data(), size_of_head);
Guolin Ke's avatar
Guolin Ke committed
390
391
392
393
394

  if (read_cnt != size_of_head) {
    Log::Fatal("Binary file error: header is incorrect");
  }
  // get header
Guolin Ke's avatar
Guolin Ke committed
395
  const char* mem_ptr = buffer.data();
Guolin Ke's avatar
Guolin Ke committed
396
  dataset->num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
397
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_data_));
Guolin Ke's avatar
Guolin Ke committed
398
  dataset->num_features_ = *(reinterpret_cast<const int*>(mem_ptr));
399
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_features_));
Guolin Ke's avatar
Guolin Ke committed
400
  dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
401
402
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(dataset->num_total_features_));
Guolin Ke's avatar
Guolin Ke committed
403
  dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
404
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->label_idx_));
405
  dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
406
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->max_bin_));
407
  dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
408
409
  mem_ptr += VirtualFileWriter::AlignedSize(
      sizeof(dataset->bin_construct_sample_cnt_));
410
  dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
411
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->min_data_in_bin_));
412
  dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
413
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->use_missing_));
414
  dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
415
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->zero_as_missing_));
416
417
  dataset->has_raw_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->has_raw_));
Guolin Ke's avatar
Guolin Ke committed
418
419
  const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
  dataset->used_feature_map_.clear();
Guolin Ke's avatar
Guolin Ke committed
420
  for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
421
422
    dataset->used_feature_map_.push_back(tmp_feature_map[i]);
  }
423
424
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int) *
                                            dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
425
426
  // num_groups
  dataset->num_groups_ = *(reinterpret_cast<const int*>(mem_ptr));
427
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
428
429
430
431
432
433
  // real_feature_idx_
  const int* tmp_ptr_real_feature_idx_ = reinterpret_cast<const int*>(mem_ptr);
  dataset->real_feature_idx_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->real_feature_idx_.push_back(tmp_ptr_real_feature_idx_[i]);
  }
434
435
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
436
437
438
439
440
441
  // feature2group
  const int* tmp_ptr_feature2group = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2group_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2group_.push_back(tmp_ptr_feature2group[i]);
  }
442
443
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
444
445
446
447
448
449
  // feature2subfeature
  const int* tmp_ptr_feature2subfeature = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2subfeature_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2subfeature_.push_back(tmp_ptr_feature2subfeature[i]);
  }
450
451
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
452
453
454
455
456
457
458
459
460
461
462
  // group_bin_boundaries
  const uint64_t* tmp_ptr_group_bin_boundaries = reinterpret_cast<const uint64_t*>(mem_ptr);
  dataset->group_bin_boundaries_.clear();
  for (int i = 0; i < dataset->num_groups_ + 1; ++i) {
    dataset->group_bin_boundaries_.push_back(tmp_ptr_group_bin_boundaries[i]);
  }
  mem_ptr += sizeof(uint64_t) * (dataset->num_groups_ + 1);

  // group_feature_start_
  const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_start_.clear();
463
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
464
465
    dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
  }
466
467
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
468
469
470
471
472
473
474

  // group_feature_cnt_
  const int* tmp_ptr_group_feature_cnt = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_cnt_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_cnt_.push_back(tmp_ptr_group_feature_cnt[i]);
  }
475
476
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
477

Belinda Trotta's avatar
Belinda Trotta committed
478
  if (!config_.max_bin_by_feature.empty()) {
479
480
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
481
482
483
484
485
486
487
488
489
    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
  } else {
    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
    dataset->max_bin_by_feature_.clear();
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
    }
  }
490
491
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int32_t) *
                                            (dataset->num_total_features_));
Belinda Trotta's avatar
Belinda Trotta committed
492
493
494
495
  if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
    dataset->max_bin_by_feature_.clear();
  }

Guolin Ke's avatar
Guolin Ke committed
496
  // get feature names
Guolin Ke's avatar
Guolin Ke committed
497
  dataset->feature_names_.clear();
Guolin Ke's avatar
Guolin Ke committed
498
499
500
  // read feature names
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int str_len = *(reinterpret_cast<const int*>(mem_ptr));
501
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
Guolin Ke's avatar
Guolin Ke committed
502
    std::stringstream str_buf;
503
    auto tmp_arr = reinterpret_cast<const char*>(mem_ptr);
Guolin Ke's avatar
Guolin Ke committed
504
    for (int j = 0; j < str_len; ++j) {
505
      char tmp_char = tmp_arr[j];
Guolin Ke's avatar
Guolin Ke committed
506
507
      str_buf << tmp_char;
    }
508
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(char) * str_len);
Guolin Ke's avatar
Guolin Ke committed
509
    dataset->feature_names_.emplace_back(str_buf.str());
Guolin Ke's avatar
Guolin Ke committed
510
  }
511
512
513
514
  // get forced_bin_bounds_
  dataset->forced_bin_bounds_ = std::vector<std::vector<double>>(dataset->num_total_features_, std::vector<double>());
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int num_bounds = *(reinterpret_cast<const int*>(mem_ptr));
515
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
516
    dataset->forced_bin_bounds_[i] = std::vector<double>();
517
518
    const double* tmp_ptr_forced_bounds =
        reinterpret_cast<const double*>(mem_ptr);
519
520
521
522
523
524
    for (int j = 0; j < num_bounds; ++j) {
      double bound = tmp_ptr_forced_bounds[j];
      dataset->forced_bin_bounds_[i].push_back(bound);
    }
    mem_ptr += num_bounds * sizeof(double);
  }
Guolin Ke's avatar
Guolin Ke committed
525
526

  // read size of meta data
527
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
528

529
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
530
531
532
    Log::Fatal("Binary file error: meta data has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
533
  size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
534
535
536
537

  // re-allocate space if not enough
  if (size_of_metadata > buffer_size) {
    buffer_size = size_of_metadata;
Guolin Ke's avatar
Guolin Ke committed
538
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
539
540
  }
  //  read meta data
541
  read_cnt = reader->Read(buffer.data(), size_of_metadata);
Guolin Ke's avatar
Guolin Ke committed
542
543
544
545
546

  if (read_cnt != size_of_metadata) {
    Log::Fatal("Binary file error: meta data is incorrect");
  }
  // load meta data
Guolin Ke's avatar
Guolin Ke committed
547
  dataset->metadata_.LoadFromMemory(buffer.data());
Guolin Ke's avatar
Guolin Ke committed
548

549
550
  *num_global_data = dataset->num_data_;
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
551
  // sample local used data if need to partition
Guolin Ke's avatar
Guolin Ke committed
552
  if (num_machines > 1 && !config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
553
554
555
556
    const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
557
        if (random_.NextShort(0, num_machines) == rank) {
558
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
559
560
561
562
563
564
565
566
567
        }
      }
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = dataset->metadata_.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (qid >= num_queries) {
568
569
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
570
571
572
573
        }
        if (i >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
574
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
575
576
577
578
579
            is_query_used = true;
          }
          ++qid;
        }
        if (is_query_used) {
580
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
581
582
583
        }
      }
    }
584
    dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
Guolin Ke's avatar
Guolin Ke committed
585
  }
586
  dataset->metadata_.PartitionLabel(*used_data_indices);
Guolin Ke's avatar
Guolin Ke committed
587
  // read feature data
Guolin Ke's avatar
Guolin Ke committed
588
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
589
    // read feature size
590
591
    read_cnt = reader->Read(buffer.data(), sizeof(size_t));
    if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
592
593
      Log::Fatal("Binary file error: feature %d has the wrong size", i);
    }
Guolin Ke's avatar
Guolin Ke committed
594
    size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
595
596
597
    // re-allocate space if not enough
    if (size_of_feature > buffer_size) {
      buffer_size = size_of_feature;
Guolin Ke's avatar
Guolin Ke committed
598
      buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
599
600
    }

601
    read_cnt = reader->Read(buffer.data(), size_of_feature);
Guolin Ke's avatar
Guolin Ke committed
602
603
604
605

    if (read_cnt != size_of_feature) {
      Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
    }
Guolin Ke's avatar
Guolin Ke committed
606
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
607
608
      new FeatureGroup(buffer.data(),
                       *num_global_data,
609
                       *used_data_indices, i)));
Guolin Ke's avatar
Guolin Ke committed
610
  }
Guolin Ke's avatar
Guolin Ke committed
611
  dataset->feature_groups_.shrink_to_fit();
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647

  // raw data
  dataset->numeric_feature_map_ = std::vector<int>(dataset->num_features_, false);
  dataset->num_numeric_features_ = 0;
  for (int i = 0; i < dataset->num_features_; ++i) {
    if (dataset->FeatureBinMapper(i)->bin_type() == BinType::CategoricalBin) {
      dataset->numeric_feature_map_[i] = -1;
    } else {
      dataset->numeric_feature_map_[i] = dataset->num_numeric_features_;
      ++dataset->num_numeric_features_;
    }
  }
  if (dataset->has_raw()) {
    dataset->ResizeRaw(dataset->num_data());
      size_t row_size = dataset->num_numeric_features_ * sizeof(float);
      if (row_size > buffer_size) {
        buffer_size = row_size;
        buffer.resize(buffer_size);
      }
    for (int i = 0; i < dataset->num_data(); ++i) {
      read_cnt = reader->Read(buffer.data(), row_size);
      if (read_cnt != row_size) {
        Log::Fatal("Binary file error: row %d of raw data is incorrect, read count: %d", i, read_cnt);
      }
      mem_ptr = buffer.data();
      const float* tmp_ptr_raw_row = reinterpret_cast<const float*>(mem_ptr);
      for (int j = 0; j < dataset->num_features(); ++j) {
        int feat_ind = dataset->numeric_feature_map_[j];
        if (feat_ind >= 0) {
          dataset->raw_data_[feat_ind][i] = tmp_ptr_raw_row[feat_ind];
        }
      }
      mem_ptr += row_size;
    }
  }

Guolin Ke's avatar
Guolin Ke committed
648
  dataset->is_finish_load_ = true;
Guolin Ke's avatar
Guolin Ke committed
649
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
650
651
}

652

653
654
655
/*!
 * \brief Find bin mappers from column-wise sampled values and construct a Dataset from them.
 *
 * With a single machine every column is binned locally. With several machines the
 * features are split into contiguous ranges, each rank bins only its own range, and
 * the serialized bin mappers are exchanged with Network::Allgather so that every
 * rank ends up with the mappers of all features.
 *
 * \param sample_values sample_values[c] is handed to BinMapper::FindBin together with num_per_col[c]
 * \param sample_indices Forwarded unchanged to Dataset::Construct
 * \param num_col Number of columns available locally
 * \param num_per_col num_per_col[c] is the number of sampled values passed for column c
 * \param total_sample_size Total number of sampled rows
 * \param num_data Number of rows the resulting Dataset is created with
 * \return Newly allocated Dataset, released from its unique_ptr; the caller takes ownership
 */
Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
                                                int** sample_indices, int num_col, const int* num_per_col,
                                                size_t total_sample_size, data_size_t num_data) {
  CheckSampleSize(total_sample_size, static_cast<size_t>(num_data));
  int num_total_features = num_col;
  if (Network::num_machines() > 1) {
    num_total_features = Network::GlobalSyncUpByMax(num_total_features);
  }
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_total_features);
  // fill feature_names_ if not header
  if (feature_names_.empty()) {
    for (int i = 0; i < num_col; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
  if (!config_.max_bin_by_feature.empty()) {
    CHECK_EQ(static_cast<size_t>(num_col), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
  }

  // get forced split
  // NOTE: forced_bin_bounds holds one entry per local column and is indexed by the
  // global column index, exactly like sample_values and num_per_col.
  std::string forced_bins_path = config_.forcedbins_filename;
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);

  const data_size_t filter_cnt = static_cast<data_size_t>(
    static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
  if (Network::num_machines() == 1) {
    // if only one machine, find bin locally
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
        bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
        if (!feat_is_unconstrained) {
            Log::Fatal("The output cannot be monotone with respect to categorical features");
        }
      }
      bin_mappers[i].reset(new BinMapper());
      // a per-feature max_bin overrides the global one when it is configured
      const int max_bin = config_.max_bin_by_feature.empty() ? config_.max_bin : config_.max_bin_by_feature[i];
      bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                              max_bin, config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter,
                              bin_type, config_.use_missing, config_.zero_as_missing,
                              forced_bin_bounds[i]);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    // if have multi-machines, need to find bin distributed
    // different machines will find bin for different features
    int num_machines = Network::num_machines();
    int rank = Network::rank();
    // start and len will store the process feature indices for different machines
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
    int step = (num_total_features + num_machines - 1) / num_machines;
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
      len[i] = std::min(step, num_total_features - start[i]);
      start[i + 1] = start[i] + len[i];
    }
    len[num_machines - 1] = num_total_features - start[num_machines - 1];
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < len[rank]; ++i) {
      OMP_LOOP_EX_BEGIN();
      // i is the slot inside this rank's range; feature_idx is the global column index
      const int feature_idx = start[rank] + i;
      if (ignore_features_.count(feature_idx) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(feature_idx)) {
        bin_type = BinType::CategoricalBin;
      }
      // the mapper is created even when no local data exists for this feature,
      // because it is serialized unconditionally below
      bin_mappers[i].reset(new BinMapper());
      if (num_col <= feature_idx) {
        continue;
      }
      const int max_bin = config_.max_bin_by_feature.empty() ? config_.max_bin : config_.max_bin_by_feature[feature_idx];
      // every per-feature input, including the forced bin bounds, must be looked up
      // with the global index; using the local slot would apply another feature's bounds
      bin_mappers[i]->FindBin(sample_values[feature_idx], num_per_col[feature_idx],
                              total_sample_size, max_bin, config_.min_data_in_bin,
                              filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
                              config_.zero_as_missing, forced_bin_bounds[feature_idx]);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
    // total serialized size of the mappers found by this rank
    comm_size_t self_buf_size = 0;
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
    }
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
      // free
      bin_mappers[i].reset(nullptr);
    }
    // per-rank buffer sizes and their offsets inside the gathered buffer
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
    }
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
    // gather global feature bin mappers
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
    // restore features bins from buffer
    for (int i = 0; i < num_total_features; ++i) {
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      bin_mappers[i].reset(new BinMapper());
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
  dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
  if (dataset->has_raw()) {
    dataset->ResizeRaw(num_data);
  }
  dataset->set_feature_names(feature_names_);
  return dataset.release();
}
Guolin Ke's avatar
Guolin Ke committed
808
809
810
811


// ---- private functions ----

812
void DatasetLoader::CheckDataset(const Dataset* dataset, bool is_load_from_binary) {
Guolin Ke's avatar
Guolin Ke committed
813
  if (dataset->num_data_ <= 0) {
Guolin Ke's avatar
Guolin Ke committed
814
    Log::Fatal("Data file %s is empty", dataset->data_filename_.c_str());
Guolin Ke's avatar
Guolin Ke committed
815
  }
816
817
  if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
    Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
818
               static_cast<int>(dataset->feature_names_.size()));
819
  }
Guolin Ke's avatar
Guolin Ke committed
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
  bool is_feature_order_by_group = true;
  int last_group = -1;
  int last_sub_feature = -1;
  // if features are ordered, not need to use hist_buf
  for (int i = 0; i < dataset->num_features_; ++i) {
    int group = dataset->feature2group_[i];
    int sub_feature = dataset->feature2subfeature_[i];
    if (group < last_group) {
      is_feature_order_by_group = false;
    } else if (group == last_group) {
      if (sub_feature <= last_sub_feature) {
        is_feature_order_by_group = false;
        break;
      }
    }
    last_group = group;
    last_sub_feature = sub_feature;
  }
  if (!is_feature_order_by_group) {
839
    Log::Fatal("Features in dataset should be ordered by group");
Guolin Ke's avatar
Guolin Ke committed
840
  }
841
842
843

  if (is_load_from_binary) {
    if (dataset->max_bin_ != config_.max_bin) {
844
845
      Log::Fatal("Dataset was constructed with parameter max_bin=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->max_bin_, config_.max_bin);
846
847
    }
    if (dataset->min_data_in_bin_ != config_.min_data_in_bin) {
848
849
      Log::Fatal("Dataset was constructed with parameter min_data_in_bin=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->min_data_in_bin_, config_.min_data_in_bin);
850
851
    }
    if (dataset->use_missing_ != config_.use_missing) {
852
853
      Log::Fatal("Dataset was constructed with parameter use_missing=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->use_missing_, config_.use_missing);
854
855
    }
    if (dataset->zero_as_missing_ != config_.zero_as_missing) {
856
857
      Log::Fatal("Dataset was constructed with parameter zero_as_missing=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->zero_as_missing_, config_.zero_as_missing);
858
859
    }
    if (dataset->bin_construct_sample_cnt_ != config_.bin_construct_sample_cnt) {
860
861
      Log::Fatal("Dataset was constructed with parameter bin_construct_sample_cnt=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->bin_construct_sample_cnt_, config_.bin_construct_sample_cnt);
862
863
864
865
    }
    if ((dataset->max_bin_by_feature_.size() != config_.max_bin_by_feature.size()) ||
        !std::equal(dataset->max_bin_by_feature_.begin(), dataset->max_bin_by_feature_.end(),
            config_.max_bin_by_feature.begin())) {
866
      Log::Fatal("Parameter max_bin_by_feature cannot be changed when loading from binary file.");
867
868
    }

869
    if (config_.label_column != "") {
870
      Log::Warning("Parameter label_column works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
871
872
    }
    if (config_.weight_column != "") {
873
      Log::Warning("Parameter weight_column works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
874
875
    }
    if (config_.group_column != "") {
876
      Log::Warning("Parameter group_column works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
877
878
    }
    if (config_.ignore_column != "") {
879
      Log::Warning("Parameter ignore_column works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
880
    }
881
    if (config_.two_round) {
882
      Log::Warning("Parameter two_round works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
883
884
    }
    if (config_.header) {
885
      Log::Warning("Parameter header works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
886
    }
887
  }
Guolin Ke's avatar
Guolin Ke committed
888
889
890
}

std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
891
892
                                                             int rank, int num_machines, int* num_global_data,
                                                             std::vector<data_size_t>* used_data_indices) {
893
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
894
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
895
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
896
897
898
899
900
901
902
903
904
    // read all lines
    *num_global_data = text_reader.ReadAllLines();
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();

    if (query_boundaries == nullptr) {
      // if not contain query data, minimal sample unit is one record
      *num_global_data = text_reader.ReadAndFilterLines([this, rank, num_machines](data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
905
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
906
907
908
909
910
911
912
913
914
915
916
917
918
919
          return true;
        } else {
          return false;
        }
      }, used_data_indices);
    } else {
      // if contain query data, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.ReadAndFilterLines(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
920
921
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
922
923
924
925
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
926
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
927
928
929
930
931
932
933
934
935
936
937
938
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices);
    }
  }
  return std::move(text_reader.Lines());
}

std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
Guolin Ke's avatar
Guolin Ke committed
939
  int sample_cnt = config_.bin_construct_sample_cnt;
940
941
  if (static_cast<size_t>(sample_cnt) > data.size()) {
    sample_cnt = static_cast<int>(data.size());
942
  }
943
  auto sample_indices = random_.Sample(static_cast<int>(data.size()), sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
944
  std::vector<std::string> out(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
945
946
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    const size_t idx = sample_indices[i];
Guolin Ke's avatar
Guolin Ke committed
947
    out[i] = data[idx];
Guolin Ke's avatar
Guolin Ke committed
948
949
950
951
  }
  return out;
}

952
953
954
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata,
                                                               int rank, int num_machines, int* num_global_data,
                                                               std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
955
  const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
956
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
957
  std::vector<std::string> out_data;
Guolin Ke's avatar
Guolin Ke committed
958
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
959
    *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(&random_, sample_cnt, &out_data));
Guolin Ke's avatar
Guolin Ke committed
960
961
962
963
964
965
966
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      *num_global_data = text_reader.SampleAndFilterFromFile([this, rank, num_machines]
      (data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
967
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
968
969
970
971
          return true;
        } else {
          return false;
        }
Guolin Ke's avatar
Guolin Ke committed
972
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
973
974
975
976
977
978
979
980
981
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.SampleAndFilterFromFile(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
982
983
          Log::Fatal("Query id exceeds the range of the query file, "
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
984
985
986
987
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
988
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
989
990
991
992
993
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
Guolin Ke's avatar
Guolin Ke committed
994
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
995
996
997
998
999
    }
  }
  return out_data;
}

1000
1001
1002
void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
                                                    const std::vector<std::string>& sample_data,
                                                    const Parser* parser, Dataset* dataset) {
1003
  auto t1 = std::chrono::high_resolution_clock::now();
Guolin Ke's avatar
Guolin Ke committed
1004
  std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
1005
  std::vector<std::vector<int>> sample_indices;
Guolin Ke's avatar
Guolin Ke committed
1006
1007
  std::vector<std::pair<int, double>> oneline_features;
  double label;
Guolin Ke's avatar
Guolin Ke committed
1008
  for (int i = 0; i < static_cast<int>(sample_data.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1009
1010
1011
1012
    oneline_features.clear();
    // parse features
    parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
    for (std::pair<int, double>& inner_data : oneline_features) {
1013
      if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
1014
1015
        sample_values.resize(inner_data.first + 1);
        sample_indices.resize(inner_data.first + 1);
1016
      }
Guolin Ke's avatar
Guolin Ke committed
1017
      if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
1018
1019
        sample_values[inner_data.first].emplace_back(inner_data.second);
        sample_indices[inner_data.first].emplace_back(i);
Guolin Ke's avatar
Guolin Ke committed
1020
1021
1022
1023
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
1024
  dataset->feature_groups_.clear();
1025
1026
1027
1028
1029
  dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
  if (num_machines > 1) {
    dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_);
  }
  if (!feature_names_.empty()) {
1030
    CHECK_EQ(dataset->num_total_features_, static_cast<int>(feature_names_.size()));
1031
  }
Guolin Ke's avatar
Guolin Ke committed
1032

Belinda Trotta's avatar
Belinda Trotta committed
1033
  if (!config_.max_bin_by_feature.empty()) {
1034
1035
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
1036
1037
  }

1038
1039
  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
1040
1041
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path,
                                                                                    dataset->num_total_features_,
1042
1043
                                                                                    categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
1044
  // check the range of label_idx, weight_idx and group_idx
1045
1046
1047
1048
1049
  // skip label check if user input parser config file,
  // because label id is got from raw features while dataset features are consistent with customized parser.
  if (dataset->parser_config_str_.empty()) {
    CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
  }
Guolin Ke's avatar
Guolin Ke committed
1050
1051
1052
1053
  CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
  CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);

  // fill feature_names_ if not header
Guolin Ke's avatar
Guolin Ke committed
1054
  if (feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1055
1056
1057
1058
1059
1060
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
1061
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
1062
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
1063
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
1064
    static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
Guolin Ke's avatar
Guolin Ke committed
1065
1066
1067
  // start find bins
  if (num_machines == 1) {
    // if only one machine, find bin locally
1068
    OMP_INIT_EX();
1069
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
1070
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
1071
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1072
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1073
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1074
1075
        continue;
      }
1076
1077
1078
1079
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
Guolin Ke's avatar
Guolin Ke committed
1080
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
1081
1082
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
1083
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
1084
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
1085
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1086
1087
      } else {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
1088
                                sample_data.size(), config_.max_bin_by_feature[i],
1089
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
1090
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1091
      }
1092
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1093
    }
1094
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1095
1096
  } else {
    // start and len will store the process feature indices for different machines
1097
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
Guolin Ke's avatar
Guolin Ke committed
1098
1099
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
1100
    int step = (dataset->num_total_features_ + num_machines - 1) / num_machines;
Guolin Ke's avatar
Guolin Ke committed
1101
1102
1103
1104
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
1105
      len[i] = std::min(step, dataset->num_total_features_ - start[i]);
Guolin Ke's avatar
Guolin Ke committed
1106
1107
      start[i + 1] = start[i] + len[i];
    }
1108
    len[num_machines - 1] = dataset->num_total_features_ - start[num_machines - 1];
1109
    OMP_INIT_EX();
1110
    #pragma omp parallel for schedule(guided)
1111
    for (int i = 0; i < len[rank]; ++i) {
1112
      OMP_LOOP_EX_BEGIN();
1113
1114
1115
1116
1117
1118
1119
1120
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
Nikita Titov's avatar
Nikita Titov committed
1121
      if (static_cast<int>(sample_values.size()) <= start[rank] + i) {
1122
1123
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
1124
      if (config_.max_bin_by_feature.empty()) {
1125
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
1126
                                static_cast<int>(sample_values[start[rank] + i].size()),
1127
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
1128
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
1129
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1130
      } else {
1131
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
1132
                                static_cast<int>(sample_values[start[rank] + i].size()),
1133
                                sample_data.size(), config_.max_bin_by_feature[i],
1134
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type,
1135
                                config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1136
      }
1137
      OMP_LOOP_EX_END();
1138
    }
1139
    OMP_THROW_EX();
1140
    comm_size_t self_buf_size = 0;
Guolin Ke's avatar
Guolin Ke committed
1141
    for (int i = 0; i < len[rank]; ++i) {
1142
1143
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
Guolin Ke's avatar
Guolin Ke committed
1144
      }
1145
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
1146
    }
1147
1148
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1149
    for (int i = 0; i < len[rank]; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1150
1151
1152
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
1153
1154
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
1155
1156
      // free
      bin_mappers[i].reset(nullptr);
Guolin Ke's avatar
Guolin Ke committed
1157
    }
1158
1159
1160
1161
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
Guolin Ke's avatar
Guolin Ke committed
1162
    }
1163
1164
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
Guolin Ke's avatar
Guolin Ke committed
1165
    // gather global feature bin mappers
1166
1167
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1168
    // restore features bins from buffer
1169
    for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1170
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1171
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1172
1173
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
1174
      bin_mappers[i].reset(new BinMapper());
1175
1176
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
Guolin Ke's avatar
Guolin Ke committed
1177
1178
    }
  }
1179
  dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Guolin Ke's avatar
Guolin Ke committed
1180
                     Common::Vector2Ptr<double>(&sample_values).data(),
1181
                     Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
1182
  if (dataset->has_raw()) {
1183
    dataset->ResizeRaw(static_cast<int>(sample_data.size()));
1184
  }
1185
1186
1187
1188

  auto t2 = std::chrono::high_resolution_clock::now();
  Log::Info("Construct bin mappers from text data time %.2f seconds",
            std::chrono::duration<double, std::milli>(t2 - t1) * 1e-3);
Guolin Ke's avatar
Guolin Ke committed
1189
1190
1191
}

/*! \brief Extract local features from memory */
Guolin Ke's avatar
Guolin Ke committed
1192
void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
1193
1194
  std::vector<std::pair<int, double>> oneline_features;
  double tmp_label = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
1195
  auto& ref_text_data = *text_data;
1196
  std::vector<float> feature_row(dataset->num_features_);
1197
  if (!predict_fun_) {
1198
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1199
    // if doesn't need to prediction with initial model
1200
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
Guolin Ke's avatar
Guolin Ke committed
1201
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1202
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1203
1204
1205
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1206
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1207
      // set label
1208
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1209
      // free processed line:
Guolin Ke's avatar
Guolin Ke committed
1210
      ref_text_data[i].clear();
Guolin Ke's avatar
Guolin Ke committed
1211
1212
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
1213
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1214
1215
      // push data
      for (auto& inner_data : oneline_features) {
1216
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1217
1218
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1219
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1220
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1221
1222
1223
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
1224
          if (dataset->has_raw()) {
1225
            feature_row[feature_idx] = static_cast<float>(inner_data.second);
1226
          }
Guolin Ke's avatar
Guolin Ke committed
1227
1228
        } else {
          if (inner_data.first == weight_idx_) {
1229
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1230
1231
1232
1233
1234
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1235
1236
1237
1238
1239
1240
1241
1242
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1243
      dataset->FinishOneRow(tid, i, is_feature_added);
1244
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1245
    }
1246
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1247
  } else {
1248
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1249
    // if need to prediction with initial model
1250
    std::vector<double> init_score(dataset->num_data_ * num_class_);
1251
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
Guolin Ke's avatar
Guolin Ke committed
1252
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1253
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1254
1255
1256
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1257
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1258
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1259
1260
      std::vector<double> oneline_init_score(num_class_);
      predict_fun_(oneline_features, oneline_init_score.data());
1261
      for (int k = 0; k < num_class_; ++k) {
1262
        init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1263
1264
      }
      // set label
1265
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1266
      // free processed line:
1267
      ref_text_data[i].clear();
Andrew Ziem's avatar
Andrew Ziem committed
1268
      // shrink_to_fit will be very slow in Linux, and seems not free memory, disable for now
Guolin Ke's avatar
Guolin Ke committed
1269
1270
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
Guolin Ke's avatar
Guolin Ke committed
1271
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1272
      for (auto& inner_data : oneline_features) {
1273
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1274
1275
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1276
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1277
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1278
1279
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
1280
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
1281
          if (dataset->has_raw()) {
1282
            feature_row[feature_idx] = static_cast<float>(inner_data.second);
1283
          }
Guolin Ke's avatar
Guolin Ke committed
1284
1285
        } else {
          if (inner_data.first == weight_idx_) {
1286
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1287
1288
1289
1290
1291
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1292
      dataset->FinishOneRow(tid, i, is_feature_added);
1293
1294
1295
1296
1297
1298
1299
1300
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
1301
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1302
    }
1303
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1304
    // metadata_ will manage space of init_score
1305
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1306
  }
Guolin Ke's avatar
Guolin Ke committed
1307
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1308
  // text data can be free after loaded feature values
Guolin Ke's avatar
Guolin Ke committed
1309
  text_data->clear();
Guolin Ke's avatar
Guolin Ke committed
1310
1311
1312
}

/*! \brief Extract local features from file */
1313
1314
void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser,
                                            const std::vector<data_size_t>& used_data_indices, Dataset* dataset) {
1315
  std::vector<double> init_score;
1316
  if (predict_fun_) {
1317
    init_score = std::vector<double>(dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1318
1319
1320
1321
1322
1323
  }
  std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
    [this, &init_score, &parser, &dataset]
  (data_size_t start_idx, const std::vector<std::string>& lines) {
    std::vector<std::pair<int, double>> oneline_features;
    double tmp_label = 0.0f;
1324
    std::vector<float> feature_row(dataset->num_features_);
1325
    OMP_INIT_EX();
1326
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
Guolin Ke's avatar
Guolin Ke committed
1327
    for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
1328
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1329
1330
1331
1332
1333
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1334
      if (!init_score.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1335
1336
        std::vector<double> oneline_init_score(num_class_);
        predict_fun_(oneline_features, oneline_init_score.data());
1337
        for (int k = 0; k < num_class_; ++k) {
1338
          init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1339
1340
1341
        }
      }
      // set label
1342
      dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1343
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1344
1345
      // push data
      for (auto& inner_data : oneline_features) {
1346
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1347
1348
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1349
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1350
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1351
1352
1353
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
1354
          if (dataset->has_raw()) {
1355
            feature_row[feature_idx] = static_cast<float>(inner_data.second);
1356
          }
Guolin Ke's avatar
Guolin Ke committed
1357
1358
        } else {
          if (inner_data.first == weight_idx_) {
1359
            dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1360
1361
1362
1363
1364
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1365
1366
1367
1368
1369
1370
1371
1372
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1373
      dataset->FinishOneRow(tid, i, is_feature_added);
1374
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1375
    }
1376
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1377
  };
1378
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
1379
  if (!used_data_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1380
1381
1382
1383
1384
1385
1386
1387
    // only need part of data
    text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
  } else {
    // need full data
    text_reader.ReadAllAndProcessParallel(process_fun);
  }

  // metadata_ will manage space of init_score
Guolin Ke's avatar
Guolin Ke committed
1388
  if (!init_score.empty()) {
1389
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1390
  }
Guolin Ke's avatar
Guolin Ke committed
1391
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1392
1393
1394
}

/*! \brief Check can load from binary file */
1395
std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
1396
1397
1398
  std::string bin_filename(filename);
  bin_filename.append(".bin");

1399
  auto reader = VirtualFileReader::Make(bin_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
1400

1401
  if (!reader->Init()) {
1402
    bin_filename = std::string(filename);
1403
1404
    reader = VirtualFileReader::Make(bin_filename.c_str());
    if (!reader->Init()) {
1405
      Log::Fatal("Cannot open data file %s", bin_filename.c_str());
1406
    }
1407
  }
1408
1409
1410
1411
1412

  size_t buffer_size = 256;
  auto buffer = std::vector<char>(buffer_size);
  // read size of token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
1413
  size_t read_cnt = reader->Read(buffer.data(), size_of_token);
1414
1415
  if (read_cnt == size_of_token
      && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
1416
    return bin_filename;
Guolin Ke's avatar
Guolin Ke committed
1417
  } else {
1418
    return std::string();
Guolin Ke's avatar
Guolin Ke committed
1419
1420
1421
  }
}

1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
                                                              const std::unordered_set<int>& categorical_features) {
  std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
  if (forced_bins_path != "") {
    std::ifstream forced_bins_stream(forced_bins_path.c_str());
    if (forced_bins_stream.fail()) {
      Log::Warning("Could not open %s. Will ignore.", forced_bins_path.c_str());
    } else {
      std::stringstream buffer;
      buffer << forced_bins_stream.rdbuf();
      std::string err;
Guolin Ke's avatar
Guolin Ke committed
1433
      Json forced_bins_json = Json::parse(buffer.str(), &err);
1434
1435
1436
1437
      CHECK(forced_bins_json.is_array());
      std::vector<Json> forced_bins_arr = forced_bins_json.array_items();
      for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
        int feature_num = forced_bins_arr[i]["feature"].int_value();
Nikita Titov's avatar
Nikita Titov committed
1438
        CHECK_LT(feature_num, num_total_features);
1439
        if (categorical_features.count(feature_num)) {
1440
          Log::Warning("Feature %d is categorical. Will ignore forced bins for this feature.", feature_num);
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
        } else {
          std::vector<Json> bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items();
          for (size_t j = 0; j < bounds_arr.size(); ++j) {
            forced_bins[feature_num].push_back(bounds_arr[j].number_value());
          }
        }
      }
      // remove duplicates
      for (int i = 0; i < num_total_features; ++i) {
        auto new_end = std::unique(forced_bins[i].begin(), forced_bins[i].end());
        forced_bins[i].erase(new_end, forced_bins[i].end());
      }
    }
  }
  return forced_bins;
}

1458
}  // namespace LightGBM