dataset_loader.cpp 68 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/dataset_loader.h>

Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/network.h>
8
#include <LightGBM/utils/array_args.h>
9
#include <LightGBM/utils/json11.h>
10
11
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
12

13
#include <algorithm>
14
#include <chrono>
15
#include <fstream>
16
17
18
19
20
21
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
22

Guolin Ke's avatar
Guolin Ke committed
23
24
namespace LightGBM {

25
using json11_internal_lightgbm::Json;
26

Guolin Ke's avatar
Guolin Ke committed
27
28
// Construct a loader bound to one config / data file.
// Column defaults: label is column 0 unless SetHeader() resolves another;
// weight and group/query columns are absent (NO_SPECIFIC) unless configured.
DatasetLoader::DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename)
  :config_(io_config), random_(config_.data_random_seed), predict_fun_(predict_fun), num_class_(num_class) {
  label_idx_ = 0;
  weight_idx_ = NO_SPECIFIC;
  group_idx_ = NO_SPECIFIC;
  SetHeader(filename);
  // Raw feature values are only retained when linear trees are requested.
  store_raw_ = io_config.linear_tree;
}

// Nothing to release explicitly; members clean up via their own destructors.
DatasetLoader::~DatasetLoader() {}

Guolin Ke's avatar
Guolin Ke committed
42
// Resolve the label/weight/group/ignore/categorical column settings against the
// data file's header (or the parser config's header), translating "name:<col>"
// references into numeric column indices. Side effects: fills feature_names_,
// label_idx_, weight_idx_, group_idx_, ignore_features_, categorical_features_.
// Calls Log::Fatal on any unresolvable column reference.
void DatasetLoader::SetHeader(const char* filename) {
  std::unordered_map<std::string, int> name2idx;
  std::string name_prefix("name:");
  // Name-based resolution only applies when loading from a text file
  // (an empty CheckCanLoadFromBin result means "no usable binary file").
  if (filename != nullptr && CheckCanLoadFromBin(filename) == "") {
    TextReader<data_size_t> text_reader(filename, config_.header);

    // get column names
    if (config_.header) {
      std::string first_line = text_reader.first_line();
      feature_names_ = Common::Split(first_line.c_str(), "\t,");
    } else if (!config_.parser_config_file.empty()) {
      // support to get header from parser config, so could utilize following label name to id mapping logic.
      TextReader<data_size_t> parser_config_reader(config_.parser_config_file.c_str(), false);
      parser_config_reader.ReadAllLines();
      std::string parser_config_str = parser_config_reader.JoinedLines();
      if (!parser_config_str.empty()) {
        std::string header_in_parser_config = Common::GetFromParserConfig(parser_config_str, "header");
        if (!header_in_parser_config.empty()) {
          Log::Info("Get raw column names from parser config.");
          feature_names_ = Common::Split(header_in_parser_config.c_str(), "\t,");
        }
      }
    }

    // load label idx first
    if (config_.label_column.size() > 0) {
      if (Common::StartsWith(config_.label_column, name_prefix)) {
        // "name:<column>" form: linear scan of the header names.
        std::string name = config_.label_column.substr(name_prefix.size());
        label_idx_ = -1;
        for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
          if (name == feature_names_[i]) {
            label_idx_ = i;
            break;
          }
        }
        if (label_idx_ >= 0) {
          Log::Info("Using column %s as label", name.c_str());
        } else {
          Log::Fatal("Could not find label column %s in data file \n"
                     "or data file doesn't contain header", name.c_str());
        }
      } else {
        // Plain numeric form: parse directly into label_idx_.
        if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
          Log::Fatal("label_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
        }
        Log::Info("Using column number %d as label", label_idx_);
      }
    }

    if (!config_.parser_config_file.empty()) {
      // if parser config file exists, feature names may be changed after customized parser applied.
      // clear here so could use default filled feature names during dataset construction.
      // may improve by saving real feature names defined in parser in the future.
      if (!feature_names_.empty()) {
        feature_names_.clear();
      }
    }

    if (!feature_names_.empty()) {
      // erase label column name, then index the remaining feature names.
      // NOTE: indices in name2idx are therefore post-label-removal positions.
      feature_names_.erase(feature_names_.begin() + label_idx_);
      for (size_t i = 0; i < feature_names_.size(); ++i) {
        name2idx[feature_names_[i]] = static_cast<int>(i);
      }
    }

    // load ignore columns
    if (config_.ignore_column.size() > 0) {
      if (Common::StartsWith(config_.ignore_column, name_prefix)) {
        // Comma-separated list of column names.
        std::string names = config_.ignore_column.substr(name_prefix.size());
        for (auto name : Common::Split(names.c_str(), ',')) {
          if (name2idx.count(name) > 0) {
            int tmp = name2idx[name];
            ignore_features_.emplace(tmp);
          } else {
            Log::Fatal("Could not find ignore column %s in data file", name.c_str());
          }
        }
      } else {
        // Comma-separated list of numeric indices.
        for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
          int tmp = 0;
          if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
            Log::Fatal("ignore_column is not a number,\n"
                       "if you want to use a column name,\n"
                       "please add the prefix \"name:\" to the column name");
          }
          ignore_features_.emplace(tmp);
        }
      }
    }
    // load weight idx
    if (config_.weight_column.size() > 0) {
      if (Common::StartsWith(config_.weight_column, name_prefix)) {
        std::string name = config_.weight_column.substr(name_prefix.size());
        if (name2idx.count(name) > 0) {
          weight_idx_ = name2idx[name];
          Log::Info("Using column %s as weight", name.c_str());
        } else {
          Log::Fatal("Could not find weight column %s in data file", name.c_str());
        }
      } else {
        if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
          Log::Fatal("weight_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
        }
        Log::Info("Using column number %d as weight", weight_idx_);
      }
      // The weight column is metadata, not a feature.
      ignore_features_.emplace(weight_idx_);
    }
    // load group idx
    if (config_.group_column.size() > 0) {
      if (Common::StartsWith(config_.group_column, name_prefix)) {
        std::string name = config_.group_column.substr(name_prefix.size());
        if (name2idx.count(name) > 0) {
          group_idx_ = name2idx[name];
          Log::Info("Using column %s as group/query id", name.c_str());
        } else {
          Log::Fatal("Could not find group/query column %s in data file", name.c_str());
        }
      } else {
        if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
          Log::Fatal("group_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
        }
        Log::Info("Using column number %d as group/query id", group_idx_);
      }
      // The group/query column is metadata, not a feature.
      ignore_features_.emplace(group_idx_);
    }
  }
  // Categorical features are resolved even when not loading from text
  // (name-based resolution still requires the name2idx map built above).
  if (config_.categorical_feature.size() > 0) {
    if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
      std::string names = config_.categorical_feature.substr(name_prefix.size());
      for (auto name : Common::Split(names.c_str(), ',')) {
        if (name2idx.count(name) > 0) {
          int tmp = name2idx[name];
          categorical_features_.emplace(tmp);
        } else {
          Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
        }
      }
    } else {
      for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
        int tmp = 0;
        if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
          Log::Fatal("categorical_feature is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
        }
        categorical_features_.emplace(tmp);
      }
    }
  }
}

200
201
202
203
204
205
206
207
208
209
// Warn when the bin-construction sample covers less than 20% of the data and
// is under the 100k absolute floor — too small a sample can yield poor bins.
void CheckSampleSize(size_t sample_cnt, size_t num_data) {
  const double sampled_fraction = static_cast<double>(sample_cnt) / num_data;
  const bool sample_too_small = sampled_fraction < 0.2f && sample_cnt < 100000;
  if (sample_too_small) {
    Log::Warning(
        "Using too small ``bin_construct_sample_cnt`` may encounter "
        "unexpected "
        "errors and poor accuracy.");
  }
}

210
// Load a training Dataset from `filename`. Three paths:
//   1) text, single pass (!two_round): read all rows into memory, sample, bin, extract;
//   2) text, two rounds: sample from file, bin, then re-read file to extract;
//   3) a matching binary file exists: load it directly via LoadFromBinFile.
// `rank`/`num_machines` drive row partitioning for distributed training.
// Returns an owning raw pointer (caller takes ownership via release()).
Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_machines) {
  // don't support query id in data file when using distributed training
  if (num_machines > 1 && !config_.pre_partition) {
    if (group_idx_ > 0) {
      Log::Fatal("Using a query id without pre-partitioning the data file is not supported for distributed training.\n"
                 "Please use an additional query file or pre-partition the data");
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  if (store_raw_) {
    dataset->SetHasRaw(true);
  }
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto bin_filename = CheckCanLoadFromBin(filename);
  bool is_load_from_binary = false;
  if (bin_filename.size() == 0) {
    // Text path: build a parser (possibly from a custom parser config).
    dataset->parser_config_str_ = Parser::GenerateParserConfigStr(filename, config_.parser_config_file.c_str(), config_.header, label_idx_);
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_,
                                                               config_.precise_float_parser, dataset->parser_config_str_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data to memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // sample data
      auto sample_data = SampleTextDataFromMemory(text_data);
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers & clear sample data
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // swap-with-temporary actually releases the sample buffer's capacity.
      std::vector<std::string>().swap(sample_data);
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      // sample data from file
      auto sample_data = SampleTextDataFromFile(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      // A non-empty index list means this machine only keeps a partition.
      if (used_data_indices.size() > 0) {
        dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
      } else {
        dataset->num_data_ = num_global_data;
      }
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers & clear sample data
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      std::vector<std::string>().swap(sample_data);
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      Log::Info("Making second pass...");
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    is_load_from_binary = true;
    Log::Info("Load from binary file %s", bin_filename.c_str());
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));

    // checks whether there's a initial score file when loaded from binary data files
    // the initial score file should with suffix ".bin.init"
    dataset->metadata_.LoadInitialScore(bin_filename);

    dataset->device_type_ = config_.device_type;
    dataset->gpu_device_id_ = config_.gpu_device_id;
    #ifdef USE_CUDA
    if (config_.device_type == std::string("cuda")) {
      dataset->CreateCUDAColumnData();
      dataset->metadata_.CreateCUDAMetadata(dataset->gpu_device_id_);
    } else {
      dataset->cuda_column_data_ = nullptr;
    }
    #endif  // USE_CUDA
  }
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  // need to check training data
  CheckDataset(dataset.get(), is_load_from_binary);

  return dataset.release();
}

306
// Load a validation Dataset whose bin mappers are copied from `train_data`
// (via CreateValid) instead of being re-estimated. Always single-machine
// (rank 0 of 1). Returns an owning raw pointer.
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  if (store_raw_) {
    dataset->SetHasRaw(true);
  }
  auto bin_filename = CheckCanLoadFromBin(filename);
  if (bin_filename.size() == 0) {
    // Reuse the training dataset's parser configuration so both files are
    // parsed identically.
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_,
                                                               config_.precise_float_parser, train_data->parser_config_str_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data in memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // Adopt the training dataset's feature/bin structure.
      dataset->CreateValid(train_data);
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      TextReader<data_size_t> text_reader(filename, config_.header);
      // Get number of lines of data file
      dataset->num_data_ = static_cast<data_size_t>(text_reader.CountLine());
      num_global_data = dataset->num_data_;
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      dataset->CreateValid(train_data);
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), 0, 1, &num_global_data, &used_data_indices));
    // checks whether there's a initial score file when loaded from binary data files
    // the initial score file should with suffix ".bin.init"
    dataset->metadata_.LoadInitialScore(bin_filename);
  }
  // not need to check validation data
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  return dataset.release();
}

363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
// Build a Dataset skeleton (bin mappers, feature groups, metadata layout) from
// an in-memory serialized "reference" blob produced by Dataset serialization.
// The blob layout is: token | version | header-size | header | per-group
// (feature-size | feature definition). No row data is read; `num_data` and
// `num_classes` size the metadata instead.
// Returns an owning raw pointer (caller takes ownership via release()).
Dataset* DatasetLoader::LoadFromSerializedReference(const char* binary_data, size_t buffer_size, data_size_t num_data, int32_t num_classes) {
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));

  auto mem_ptr = binary_data;

  // check token
  const size_t size_of_token = std::strlen(Dataset::binary_serialized_reference_token);
  size_t size_of_token_in_input = VirtualFileWriter::AlignedSize(sizeof(char) * size_of_token);
  if (buffer_size < size_of_token_in_input) {
    Log::Fatal("Binary definition file error: token has the wrong size");
  }
  if (std::string(mem_ptr, size_of_token) != std::string(Dataset::binary_serialized_reference_token)) {
    Log::Fatal("Input file is not LightGBM binary reference file");
  }
  mem_ptr += size_of_token_in_input;

  // check version, guarding against reading past the supplied buffer
  size_t size_of_version = VirtualFileWriter::AlignedSize(Dataset::kSerializedReferenceVersionLength);
  if (buffer_size < size_of_token_in_input + size_of_version) {
    Log::Fatal("Binary definition file error: version has the wrong size");
  }
  std::string version(mem_ptr, Dataset::kSerializedReferenceVersionLength);
  if (version != std::string(Dataset::serialized_reference_version)) {
    Log::Fatal("Unexpected version of serialized binary data: %s", version.c_str());
  }
  mem_ptr += size_of_version;

  // header: a size prefix followed by the serialized header itself
  size_t size_of_header = *(reinterpret_cast<const size_t*>(mem_ptr));
  mem_ptr += sizeof(size_t);

  LoadHeaderFromMemory(dataset.get(), mem_ptr);
  dataset->num_data_ = num_data;  // update to the given num_data
  mem_ptr += size_of_header;

  // read feature group definitions
  for (int i = 0; i < dataset->num_groups_; ++i) {
    // read feature size
    const size_t size_of_feature = *(reinterpret_cast<const size_t*>(mem_ptr));
    mem_ptr += sizeof(size_t);
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(new FeatureGroup(mem_ptr, num_data, i)));
    mem_ptr += size_of_feature;
  }
  dataset->feature_groups_.shrink_to_fit();

  // Map each feature to its numeric-feature slot (-1 for categorical features).
  dataset->numeric_feature_map_ = std::vector<int>(dataset->num_features_, 0);
  dataset->num_numeric_features_ = 0;
  for (int i = 0; i < dataset->num_features_; ++i) {
    if (dataset->FeatureBinMapper(i)->bin_type() == BinType::CategoricalBin) {
      dataset->numeric_feature_map_[i] = -1;
    } else {
      dataset->numeric_feature_map_[i] = dataset->num_numeric_features_;
      ++dataset->num_numeric_features_;
    }
  }

  // Size the metadata from the config rather than the blob: the reference
  // carries no per-row weights/queries/scores.
  int has_weights = config_.weight_column.size() > 0;
  int has_init_scores = num_classes > 0;
  int has_queries = config_.group_column.size() > 0;
  dataset->metadata_.Init(num_data, has_weights, has_init_scores, has_queries, num_classes);

  Log::Info("Loaded reference dataset: %d features, %d num_data", dataset->num_features_, num_data);

  return dataset.release();
}

424
425
426
// Load a Dataset from a previously saved LightGBM binary file.
// File layout (as read here): token | header-size | header | metadata-size |
// metadata | per-group (feature-size | feature data) | optional raw rows.
// For distributed runs without pre-partitioning, rows (or whole queries) are
// randomly assigned to this `rank`, filling *used_data_indices.
// Returns an owning raw pointer (caller takes ownership via release()).
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
                                        int rank, int num_machines, int* num_global_data,
                                        std::vector<data_size_t>* used_data_indices) {
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  auto reader = VirtualFileReader::Make(bin_filename);
  dataset->data_filename_ = data_filename;
  if (!reader->Init()) {
    Log::Fatal("Could not read binary data from %s", bin_filename);
  }

  // buffer to read binary file
  size_t buffer_size = 16 * 1024 * 1024;
  auto buffer = std::vector<char>(buffer_size);

  // check token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
  size_t read_cnt = reader->Read(
      buffer.data(),
      VirtualFileWriter::AlignedSize(sizeof(char) * size_of_token));
  if (read_cnt < sizeof(char) * size_of_token) {
    Log::Fatal("Binary file error: token has the wrong size");
  }
  if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
    Log::Fatal("Input file is not LightGBM binary file");
  }

  // read size of header
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));

  if (read_cnt != sizeof(size_t)) {
    Log::Fatal("Binary file error: header has the wrong size");
  }

  size_t size_of_head = *(reinterpret_cast<size_t*>(buffer.data()));

  // re-allocate space if not enough
  if (size_of_head > buffer_size) {
    buffer_size = size_of_head;
    buffer.resize(buffer_size);
  }
  // read header
  read_cnt = reader->Read(buffer.data(), size_of_head);

  if (read_cnt != size_of_head) {
    Log::Fatal("Binary file error: header is incorrect");
  }
  // get header
  const char* mem_ptr = buffer.data();
  LoadHeaderFromMemory(dataset.get(), mem_ptr);

  // read size of meta data
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));

  if (read_cnt != sizeof(size_t)) {
    Log::Fatal("Binary file error: meta data has the wrong size");
  }

  size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer.data()));

  // re-allocate space if not enough
  if (size_of_metadata > buffer_size) {
    buffer_size = size_of_metadata;
    buffer.resize(buffer_size);
  }
  //  read meta data
  read_cnt = reader->Read(buffer.data(), size_of_metadata);

  if (read_cnt != size_of_metadata) {
    Log::Fatal("Binary file error: meta data is incorrect");
  }
  // load meta data
  dataset->metadata_.LoadFromMemory(buffer.data());

  *num_global_data = dataset->num_data_;
  used_data_indices->clear();
  // sample local used data if need to partition
  if (num_machines > 1 && !config_.pre_partition) {
    const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (random_.NextShort(0, num_machines) == rank) {
          used_data_indices->push_back(i);
        }
      }
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = dataset->metadata_.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (qid >= num_queries) {
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
        }
        if (i >= query_boundaries[qid + 1]) {
          // if is new query: decide once whether this whole query stays local
          is_query_used = false;
          if (random_.NextShort(0, num_machines) == rank) {
            is_query_used = true;
          }
          ++qid;
        }
        if (is_query_used) {
          used_data_indices->push_back(i);
        }
      }
    }
    dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
  }
  dataset->metadata_.PartitionLabel(*used_data_indices);
  // read feature data
  for (int i = 0; i < dataset->num_groups_; ++i) {
    // read feature size
    read_cnt = reader->Read(buffer.data(), sizeof(size_t));
    if (read_cnt != sizeof(size_t)) {
      Log::Fatal("Binary file error: feature %d has the wrong size", i);
    }
    size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
    // re-allocate space if not enough
    if (size_of_feature > buffer_size) {
      buffer_size = size_of_feature;
      buffer.resize(buffer_size);
    }

    read_cnt = reader->Read(buffer.data(), size_of_feature);

    if (read_cnt != size_of_feature) {
      Log::Fatal("Binary file error: feature %d is incorrect, read count: %zu", i, read_cnt);
    }
    // FeatureGroup deserializes itself from the raw buffer, keeping only the
    // locally-used rows when used_data_indices is non-empty.
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
      new FeatureGroup(buffer.data(),
                       *num_global_data,
                       *used_data_indices, i)));
  }
  dataset->feature_groups_.shrink_to_fit();

  // raw data: map each feature to its numeric slot (-1 for categorical)
  dataset->numeric_feature_map_ = std::vector<int>(dataset->num_features_, false);
  dataset->num_numeric_features_ = 0;
  for (int i = 0; i < dataset->num_features_; ++i) {
    if (dataset->FeatureBinMapper(i)->bin_type() == BinType::CategoricalBin) {
      dataset->numeric_feature_map_[i] = -1;
    } else {
      dataset->numeric_feature_map_[i] = dataset->num_numeric_features_;
      ++dataset->num_numeric_features_;
    }
  }
  if (dataset->has_raw()) {
    dataset->ResizeRaw(dataset->num_data());
    // one row of raw (float) values per data point, numeric features only
    size_t row_size = dataset->num_numeric_features_ * sizeof(float);
    if (row_size > buffer_size) {
      buffer_size = row_size;
      buffer.resize(buffer_size);
    }
    for (int i = 0; i < dataset->num_data(); ++i) {
      read_cnt = reader->Read(buffer.data(), row_size);
      if (read_cnt != row_size) {
        Log::Fatal("Binary file error: row %d of raw data is incorrect, read count: %zu", i, read_cnt);
      }
      mem_ptr = buffer.data();
      const float* tmp_ptr_raw_row = reinterpret_cast<const float*>(mem_ptr);
      for (int j = 0; j < dataset->num_features(); ++j) {
        int feat_ind = dataset->numeric_feature_map_[j];
        if (feat_ind >= 0) {
          dataset->raw_data_[feat_ind][i] = tmp_ptr_raw_row[feat_ind];
        }
      }
      mem_ptr += row_size;
    }
  }

  dataset->is_finish_load_ = true;
  return dataset.release();
}

600
/*!
 * \brief Build a Dataset directly from pre-sampled column data (no text parsing).
 *
 * Finds a BinMapper per feature (locally, or distributed across machines and
 * merged via Allgather), then constructs the Dataset over the local rows.
 * \param sample_values     per-column arrays of sampled non-zero values
 * \param sample_indices    per-column arrays of the row index of each sampled value
 * \param num_col           number of columns present on this machine
 * \param num_per_col       number of sampled values per column
 * \param total_sample_size number of sampled rows
 * \param num_local_data    number of data rows on this machine
 * \param num_dist_data     total number of rows across all machines
 * \return heap-allocated Dataset; ownership transfers to the caller
 */
Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
                                                int** sample_indices,
                                                int num_col,
                                                const int* num_per_col,
                                                size_t total_sample_size,
                                                data_size_t num_local_data,
                                                int64_t num_dist_data) {
  CheckSampleSize(total_sample_size, static_cast<size_t>(num_dist_data));
  int num_total_features = num_col;
  // in distributed mode, machines may see different column counts; agree on the max
  if (Network::num_machines() > 1) {
    num_total_features = Network::GlobalSyncUpByMax(num_total_features);
  }
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_total_features);
  // fill feature_names_ if not header
  if (feature_names_.empty()) {
    for (int i = 0; i < num_col; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
  if (!config_.max_bin_by_feature.empty()) {
    CHECK_EQ(static_cast<size_t>(num_col), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
  }

  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);

  // scale min_data_in_leaf from the full data size down to the sample size
  const data_size_t filter_cnt = static_cast<data_size_t>(
    static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_dist_data);
  if (Network::num_machines() == 1) {
    // if only one machine, find bin locally
    OMP_INIT_EX();
    #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
        // monotone constraints are undefined over unordered categories
        bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
        if (!feat_is_unconstrained) {
            Log::Fatal("The output cannot be monotone with respect to categorical features");
        }
      }
      bin_mappers[i].reset(new BinMapper());
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                                config_.max_bin, config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter,
                                bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[i]);
      } else {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                                config_.max_bin_by_feature[i], config_.min_data_in_bin,
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
                                config_.zero_as_missing, forced_bin_bounds[i]);
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    // if have multi-machines, need to find bin distributed
    // different machines will find bin for different features
    int num_machines = Network::num_machines();
    int rank = Network::rank();
    // start and len will store the process feature indices for different machines
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
    int step = (num_total_features + num_machines - 1) / num_machines;
    if (step < 1) {
      step = 1;
    }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
      len[i] = std::min(step, num_total_features - start[i]);
      start[i + 1] = start[i] + len[i];
    }
    len[num_machines - 1] = num_total_features - start[num_machines - 1];
    OMP_INIT_EX();
    #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(guided)
    for (int i = 0; i < len[rank]; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      // a default-constructed mapper is kept (and later serialized) even for
      // columns beyond this machine's local num_col, so buffer sizes agree
      bin_mappers[i].reset(new BinMapper());
      if (num_col <= start[rank] + i) {
        continue;
      }
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
                                total_sample_size, config_.max_bin, config_.min_data_in_bin,
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[i]);
      } else {
        bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
                                total_sample_size, config_.max_bin_by_feature[start[rank] + i],
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
                                config_.zero_as_missing, forced_bin_bounds[i]);
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
    // size of this machine's serialized bin mappers
    comm_size_t self_buf_size = 0;
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
    }
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
    // serialize local mappers into the send buffer, releasing them as we go
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
      // free
      bin_mappers[i].reset(nullptr);
    }
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
    }
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
    // gather global feature bin mappers
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
    // restore features bins from buffer
    for (int i = 0; i < num_total_features; ++i) {
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      bin_mappers[i].reset(new BinMapper());
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
    }
  }
  CheckCategoricalFeatureNumBin(bin_mappers, config_.max_bin, config_.max_bin_by_feature);
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_local_data));
  dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
  if (dataset->has_raw()) {
    dataset->ResizeRaw(num_local_data);
  }
  dataset->set_feature_names(feature_names_);
  return dataset.release();
}
Guolin Ke's avatar
Guolin Ke committed
762
763
764
765


// ---- private functions ----

766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
/*!
 * \brief Restore the dataset header section from a serialized binary buffer.
 *
 * Fields are read in exactly the order the binary writer emitted them; after
 * each field mem_ptr is advanced by the writer's aligned size so subsequent
 * reads stay in sync with the on-disk layout. Do not reorder any reads.
 * \param dataset output dataset whose header fields are populated
 * \param buffer  start of the serialized header bytes
 */
void DatasetLoader::LoadHeaderFromMemory(Dataset* dataset, const char* buffer) {
  // get header
  const char* mem_ptr = buffer;
  dataset->num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_data_));
  dataset->num_features_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_features_));
  dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_total_features_));
  dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->label_idx_));
  dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->max_bin_));
  dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->bin_construct_sample_cnt_));
  dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->min_data_in_bin_));
  dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->use_missing_));
  dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->zero_as_missing_));
  dataset->has_raw_ = *(reinterpret_cast<const bool*>(mem_ptr));

  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->has_raw_));
  // mapping from total feature index to used feature index (or -1 if unused)
  const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
  dataset->used_feature_map_.clear();
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    dataset->used_feature_map_.push_back(tmp_feature_map[i]);
  }
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_total_features_);
  // num_groups
  dataset->num_groups_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_groups_));
  // real_feature_idx_
  const int* tmp_ptr_real_feature_idx_ = reinterpret_cast<const int*>(mem_ptr);
  dataset->real_feature_idx_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->real_feature_idx_.push_back(tmp_ptr_real_feature_idx_[i]);
  }
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
  // feature2group
  const int* tmp_ptr_feature2group = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2group_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2group_.push_back(tmp_ptr_feature2group[i]);
  }
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
  // feature2subfeature
  const int* tmp_ptr_feature2subfeature = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2subfeature_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2subfeature_.push_back(tmp_ptr_feature2subfeature[i]);
  }
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
  // group_bin_boundaries
  const uint64_t* tmp_ptr_group_bin_boundaries = reinterpret_cast<const uint64_t*>(mem_ptr);
  dataset->group_bin_boundaries_.clear();
  for (int i = 0; i < dataset->num_groups_ + 1; ++i) {
    dataset->group_bin_boundaries_.push_back(tmp_ptr_group_bin_boundaries[i]);
  }
  // NOTE(review): advance is raw sizeof, not AlignedSize — presumably matches
  // the writer since uint64_t is already 8-byte sized; confirm against writer.
  mem_ptr += sizeof(uint64_t) * (dataset->num_groups_ + 1);

  // group_feature_start_
  const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_start_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
  }
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));

  // group_feature_cnt_
  const int* tmp_ptr_group_feature_cnt = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_cnt_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_cnt_.push_back(tmp_ptr_group_feature_cnt[i]);
  }
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));

  // per-feature max_bin: config overrides the serialized values when provided
  if (!config_.max_bin_by_feature.empty()) {
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
  } else {
    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
    dataset->max_bin_by_feature_.clear();
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
    }
  }
  // the serialized array is always present, so always skip past it
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int32_t) * (dataset->num_total_features_));
  // an all -1 array means "no per-feature override"
  if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
    dataset->max_bin_by_feature_.clear();
  }

  // get feature names (each stored as length prefix + characters)
  dataset->feature_names_.clear();
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int str_len = *(reinterpret_cast<const int*>(mem_ptr));
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
    std::stringstream str_buf;
    auto tmp_arr = reinterpret_cast<const char*>(mem_ptr);
    for (int j = 0; j < str_len; ++j) {
      char tmp_char = tmp_arr[j];
      str_buf << tmp_char;
    }
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(char) * str_len);
    dataset->feature_names_.emplace_back(str_buf.str());
  }
  // get forced_bin_bounds_ (each stored as count prefix + doubles)
  dataset->forced_bin_bounds_ = std::vector<std::vector<double>>(dataset->num_total_features_, std::vector<double>());
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int num_bounds = *(reinterpret_cast<const int*>(mem_ptr));
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
    dataset->forced_bin_bounds_[i] = std::vector<double>();
    const double* tmp_ptr_forced_bounds =
      reinterpret_cast<const double*>(mem_ptr);
    for (int j = 0; j < num_bounds; ++j) {
      double bound = tmp_ptr_forced_bounds[j];
      dataset->forced_bin_bounds_[i].push_back(bound);
    }
    mem_ptr += num_bounds * sizeof(double);
  }
}

891
/*!
 * \brief Validate a constructed Dataset and, for binary loads, verify that the
 *        current config is compatible with the parameters baked into the file.
 *
 * Fatal on structural problems (empty data, name-count mismatch, unordered
 * feature groups, changed binning parameters); warns on config options that
 * only apply to text loading and are ignored for binary files.
 * \param dataset             dataset to validate
 * \param is_load_from_binary true when the dataset came from a binary file
 */
void DatasetLoader::CheckDataset(const Dataset* dataset, bool is_load_from_binary) {
  if (dataset->num_data_ <= 0) {
    Log::Fatal("Data file %s is empty", dataset->data_filename_.c_str());
  }
  if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
    Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
               static_cast<int>(dataset->feature_names_.size()));
  }
  // verify features are sorted by (group, sub_feature), strictly increasing
  // within a group
  bool is_feature_order_by_group = true;
  int last_group = -1;
  int last_sub_feature = -1;
  // if features are ordered, not need to use hist_buf
  for (int i = 0; i < dataset->num_features_; ++i) {
    int group = dataset->feature2group_[i];
    int sub_feature = dataset->feature2subfeature_[i];
    if (group < last_group) {
      is_feature_order_by_group = false;
    } else if (group == last_group) {
      if (sub_feature <= last_sub_feature) {
        is_feature_order_by_group = false;
        break;
      }
    }
    last_group = group;
    last_sub_feature = sub_feature;
  }
  if (!is_feature_order_by_group) {
    Log::Fatal("Features in dataset should be ordered by group");
  }

  if (is_load_from_binary) {
    // binning parameters are frozen into the binary file; they must match
    if (dataset->max_bin_ != config_.max_bin) {
      Log::Fatal("Dataset was constructed with parameter max_bin=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->max_bin_, config_.max_bin);
    }
    if (dataset->min_data_in_bin_ != config_.min_data_in_bin) {
      Log::Fatal("Dataset was constructed with parameter min_data_in_bin=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->min_data_in_bin_, config_.min_data_in_bin);
    }
    if (dataset->use_missing_ != config_.use_missing) {
      Log::Fatal("Dataset was constructed with parameter use_missing=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->use_missing_, config_.use_missing);
    }
    if (dataset->zero_as_missing_ != config_.zero_as_missing) {
      Log::Fatal("Dataset was constructed with parameter zero_as_missing=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->zero_as_missing_, config_.zero_as_missing);
    }
    if (dataset->bin_construct_sample_cnt_ != config_.bin_construct_sample_cnt) {
      Log::Fatal("Dataset was constructed with parameter bin_construct_sample_cnt=%d. It cannot be changed to %d when loading from binary file.",
                 dataset->bin_construct_sample_cnt_, config_.bin_construct_sample_cnt);
    }
    if ((dataset->max_bin_by_feature_.size() != config_.max_bin_by_feature.size()) ||
        !std::equal(dataset->max_bin_by_feature_.begin(), dataset->max_bin_by_feature_.end(),
            config_.max_bin_by_feature.begin())) {
      Log::Fatal("Parameter max_bin_by_feature cannot be changed when loading from binary file.");
    }

    // the following options only affect text parsing; warn that they are ignored
    if (config_.label_column != "") {
      Log::Warning("Parameter label_column works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
    }
    if (config_.weight_column != "") {
      Log::Warning("Parameter weight_column works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
    }
    if (config_.group_column != "") {
      Log::Warning("Parameter group_column works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
    }
    if (config_.ignore_column != "") {
      Log::Warning("Parameter ignore_column works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
    }
    if (config_.two_round) {
      Log::Warning("Parameter two_round works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
    }
    if (config_.header) {
      Log::Warning("Parameter header works only in case of loading data directly from text file. It will be ignored when loading from binary file.");
    }
  }
}

/*!
 * \brief Read the text data file into memory, optionally pre-partitioning rows
 *        across machines for distributed training.
 *
 * With one machine (or pre_partition enabled) every line is read. Otherwise
 * each machine keeps a random subset: unit of assignment is a single row, or a
 * whole query when query boundaries exist (so a query never spans machines).
 * \param filename          path of the text data file
 * \param metadata          metadata providing query boundaries (may have none)
 * \param rank              this machine's rank
 * \param num_machines      total number of machines
 * \param num_global_data   out: total number of lines in the file
 * \param used_data_indices out: indices of the lines kept by this machine
 * \return the kept lines, moved out of the reader
 */
std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
                                                             int rank, int num_machines, int* num_global_data,
                                                             std::vector<data_size_t>* used_data_indices) {
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
  used_data_indices->clear();
  if (num_machines == 1 || config_.pre_partition) {
    // read all lines
    *num_global_data = text_reader.ReadAllLines();
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();

    if (query_boundaries == nullptr) {
      // if not contain query data, minimal sample unit is one record
      *num_global_data = text_reader.ReadAndFilterLines([this, rank, num_machines](data_size_t) {
        if (random_.NextShort(0, num_machines) == rank) {
          return true;
        } else {
          return false;
        }
      }, used_data_indices);
    } else {
      // if contain query data, minimal sample unit is one query
      // qid / is_query_used persist across lambda invocations so that all
      // lines of a query share one keep/drop decision
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.ReadAndFilterLines(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
          if (random_.NextShort(0, num_machines) == rank) {
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices);
    }
  }
  // Lines() exposes the reader's internal storage; move it out to avoid a copy
  return std::move(text_reader.Lines());
}

std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
Guolin Ke's avatar
Guolin Ke committed
1018
  int sample_cnt = config_.bin_construct_sample_cnt;
1019
1020
  if (static_cast<size_t>(sample_cnt) > data.size()) {
    sample_cnt = static_cast<int>(data.size());
1021
  }
1022
  auto sample_indices = random_.Sample(static_cast<int>(data.size()), sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
1023
  std::vector<std::string> out(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
1024
1025
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    const size_t idx = sample_indices[i];
Guolin Ke's avatar
Guolin Ke committed
1026
    out[i] = data[idx];
Guolin Ke's avatar
Guolin Ke committed
1027
1028
1029
1030
  }
  return out;
}

1031
1032
1033
/*!
 * \brief Sample lines directly from the data file (two-round / low-memory path).
 *
 * Mirrors LoadTextDataToMemory's partitioning logic, but instead of keeping
 * every assigned line it reservoir-samples up to bin_construct_sample_cnt of
 * them via the TextReader. With queries present, the unit of machine
 * assignment is a whole query.
 * \param filename          path of the text data file
 * \param metadata          metadata providing query boundaries (may have none)
 * \param rank              this machine's rank
 * \param num_machines      total number of machines
 * \param num_global_data   out: total number of lines in the file
 * \param used_data_indices out: indices of lines assigned to this machine
 * \return the sampled lines
 */
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata,
                                                               int rank, int num_machines, int* num_global_data,
                                                               std::vector<data_size_t>* used_data_indices) {
  const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
  std::vector<std::string> out_data;
  if (num_machines == 1 || config_.pre_partition) {
    *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(&random_, sample_cnt, &out_data));
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      *num_global_data = text_reader.SampleAndFilterFromFile([this, rank, num_machines]
      (data_size_t) {
        if (random_.NextShort(0, num_machines) == rank) {
          return true;
        } else {
          return false;
        }
      }, used_data_indices, &random_, sample_cnt, &out_data);
    } else {
      // if contain query file, minimal sample unit is one query
      // qid / is_query_used persist across lambda invocations so that all
      // lines of a query share one keep/drop decision
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.SampleAndFilterFromFile(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
          Log::Fatal("Query id exceeds the range of the query file, "
                     "please ensure the query file is correct");
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
          if (random_.NextShort(0, num_machines) == rank) {
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices, &random_, sample_cnt, &out_data);
    }
  }
  return out_data;
}

1079
1080
1081
void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
                                                    const std::vector<std::string>& sample_data,
                                                    const Parser* parser, Dataset* dataset) {
1082
  auto t1 = std::chrono::high_resolution_clock::now();
Guolin Ke's avatar
Guolin Ke committed
1083
  std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
1084
  std::vector<std::vector<int>> sample_indices;
Guolin Ke's avatar
Guolin Ke committed
1085
1086
  std::vector<std::pair<int, double>> oneline_features;
  double label;
Guolin Ke's avatar
Guolin Ke committed
1087
  for (int i = 0; i < static_cast<int>(sample_data.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
1088
1089
1090
1091
    oneline_features.clear();
    // parse features
    parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
    for (std::pair<int, double>& inner_data : oneline_features) {
1092
      if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
1093
1094
        sample_values.resize(inner_data.first + 1);
        sample_indices.resize(inner_data.first + 1);
1095
      }
Guolin Ke's avatar
Guolin Ke committed
1096
      if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
1097
1098
        sample_values[inner_data.first].emplace_back(inner_data.second);
        sample_indices[inner_data.first].emplace_back(i);
Guolin Ke's avatar
Guolin Ke committed
1099
1100
1101
1102
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
1103
  dataset->feature_groups_.clear();
1104
1105
1106
1107
1108
  dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
  if (num_machines > 1) {
    dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_);
  }
  if (!feature_names_.empty()) {
1109
    CHECK_EQ(dataset->num_total_features_, static_cast<int>(feature_names_.size()));
1110
  }
Guolin Ke's avatar
Guolin Ke committed
1111

Belinda Trotta's avatar
Belinda Trotta committed
1112
  if (!config_.max_bin_by_feature.empty()) {
1113
1114
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
1115
1116
  }

1117
1118
  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
1119
1120
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path,
                                                                                    dataset->num_total_features_,
1121
1122
                                                                                    categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
1123
  // check the range of label_idx, weight_idx and group_idx
1124
1125
1126
1127
1128
  // skip label check if user input parser config file,
  // because label id is got from raw features while dataset features are consistent with customized parser.
  if (dataset->parser_config_str_.empty()) {
    CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
  }
Guolin Ke's avatar
Guolin Ke committed
1129
1130
1131
1132
  CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
  CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);

  // fill feature_names_ if not header
Guolin Ke's avatar
Guolin Ke committed
1133
  if (feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1134
1135
1136
1137
1138
1139
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
1140
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
1141
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
1142
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
1143
    static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
Guolin Ke's avatar
Guolin Ke committed
1144
1145
1146
  // start find bins
  if (num_machines == 1) {
    // if only one machine, find bin locally
1147
    OMP_INIT_EX();
1148
    #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
1149
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
1150
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1151
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1152
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1153
1154
        continue;
      }
1155
1156
1157
1158
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
Guolin Ke's avatar
Guolin Ke committed
1159
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
1160
1161
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
1162
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
1163
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
1164
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1165
1166
      } else {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
1167
                                sample_data.size(), config_.max_bin_by_feature[i],
1168
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
1169
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1170
      }
1171
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1172
    }
1173
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1174
1175
  } else {
    // start and len will store the process feature indices for different machines
1176
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
Guolin Ke's avatar
Guolin Ke committed
1177
1178
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
1179
    int step = (dataset->num_total_features_ + num_machines - 1) / num_machines;
1180
1181
1182
    if (step < 1) {
      step = 1;
    }
Guolin Ke's avatar
Guolin Ke committed
1183
1184
1185

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
1186
      len[i] = std::min(step, dataset->num_total_features_ - start[i]);
Guolin Ke's avatar
Guolin Ke committed
1187
1188
      start[i + 1] = start[i] + len[i];
    }
1189
    len[num_machines - 1] = dataset->num_total_features_ - start[num_machines - 1];
1190
    OMP_INIT_EX();
1191
    #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(guided)
1192
    for (int i = 0; i < len[rank]; ++i) {
1193
      OMP_LOOP_EX_BEGIN();
1194
1195
1196
1197
1198
1199
1200
1201
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
Nikita Titov's avatar
Nikita Titov committed
1202
      if (static_cast<int>(sample_values.size()) <= start[rank] + i) {
1203
1204
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
1205
      if (config_.max_bin_by_feature.empty()) {
1206
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
1207
                                static_cast<int>(sample_values[start[rank] + i].size()),
1208
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
1209
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
1210
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1211
      } else {
1212
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
1213
                                static_cast<int>(sample_values[start[rank] + i].size()),
1214
                                sample_data.size(), config_.max_bin_by_feature[i],
1215
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type,
1216
                                config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1217
      }
1218
      OMP_LOOP_EX_END();
1219
    }
1220
    OMP_THROW_EX();
1221
    comm_size_t self_buf_size = 0;
Guolin Ke's avatar
Guolin Ke committed
1222
    for (int i = 0; i < len[rank]; ++i) {
1223
1224
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
Guolin Ke's avatar
Guolin Ke committed
1225
      }
1226
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
1227
    }
1228
1229
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1230
    for (int i = 0; i < len[rank]; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1231
1232
1233
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
1234
1235
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
1236
1237
      // free
      bin_mappers[i].reset(nullptr);
Guolin Ke's avatar
Guolin Ke committed
1238
    }
1239
1240
1241
1242
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
Guolin Ke's avatar
Guolin Ke committed
1243
    }
1244
1245
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
Guolin Ke's avatar
Guolin Ke committed
1246
    // gather global feature bin mappers
1247
1248
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1249
    // restore features bins from buffer
1250
    for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1251
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1252
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1253
1254
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
1255
      bin_mappers[i].reset(new BinMapper());
1256
1257
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
Guolin Ke's avatar
Guolin Ke committed
1258
1259
    }
  }
1260
  CheckCategoricalFeatureNumBin(bin_mappers, config_.max_bin, config_.max_bin_by_feature);
1261
  dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Guolin Ke's avatar
Guolin Ke committed
1262
                     Common::Vector2Ptr<double>(&sample_values).data(),
1263
                     Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
1264
  if (dataset->has_raw()) {
1265
    dataset->ResizeRaw(static_cast<int>(sample_data.size()));
1266
  }
1267
1268
1269
1270

  auto t2 = std::chrono::high_resolution_clock::now();
  Log::Info("Construct bin mappers from text data time %.2f seconds",
            std::chrono::duration<double, std::milli>(t2 - t1) * 1e-3);
Guolin Ke's avatar
Guolin Ke committed
1271
1272
1273
}

/*! \brief Extract local features from memory */
Guolin Ke's avatar
Guolin Ke committed
1274
void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
  // Parses every raw text line in *text_data and pushes feature values,
  // labels, weights and query ids into `dataset`. Each line is cleared in
  // place right after it is consumed to release memory early. When
  // predict_fun_ is set (continued training from an initial model), a
  // per-row, per-class initial score is computed and stored in the metadata.
  std::vector<std::pair<int, double>> oneline_features;
  double tmp_label = 0.0f;
  auto& ref_text_data = *text_data;
  // Reusable per-row buffer of raw feature values; only written back when the
  // dataset keeps raw data (dataset->has_raw(), used for linear trees).
  std::vector<float> feature_row(dataset->num_features_);
  if (!predict_fun_) {
    OMP_INIT_EX();
    // no initial model: parse and push rows without computing init scores
    #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
      OMP_LOOP_EX_BEGIN();
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
      // set label
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
      // free processed line:
      ref_text_data[i].clear();
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
      std::vector<bool> is_feature_added(dataset->num_features_, false);
      // push data
      for (auto& inner_data : oneline_features) {
        // indices beyond the known feature count are silently dropped
        if (inner_data.first >= dataset->num_total_features_) {
          continue;
        }
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
          is_feature_added[feature_idx] = true;
          // if is used feature
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
          if (dataset->has_raw()) {
            feature_row[feature_idx] = static_cast<float>(inner_data.second);
          }
        } else {
          // not a used feature: may still be the weight or query-id column
          if (inner_data.first == weight_idx_) {
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
      // copy the numeric raw values of this row into the raw-data store
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
      dataset->FinishOneRow(tid, i, is_feature_added);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    OMP_INIT_EX();
    // initial model present: additionally predict an init score per row/class
    std::vector<double> init_score(static_cast<size_t>(dataset->num_data_) * num_class_);
    #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
      OMP_LOOP_EX_BEGIN();
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
      std::vector<double> oneline_init_score(num_class_);
      predict_fun_(oneline_features, oneline_init_score.data());
      // init_score is laid out class-major: [class][row]
      for (int k = 0; k < num_class_; ++k) {
        init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
      }
      // set label
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
      // free processed line:
      ref_text_data[i].clear();
      // shrink_to_fit will be very slow in Linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
      std::vector<bool> is_feature_added(dataset->num_features_, false);
      for (auto& inner_data : oneline_features) {
        // indices beyond the known feature count are silently dropped
        if (inner_data.first >= dataset->num_total_features_) {
          continue;
        }
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
          is_feature_added[feature_idx] = true;
          // if is used feature
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
          if (dataset->has_raw()) {
            feature_row[feature_idx] = static_cast<float>(inner_data.second);
          }
        } else {
          // not a used feature: may still be the weight or query-id column
          if (inner_data.first == weight_idx_) {
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
      dataset->FinishOneRow(tid, i, is_feature_added);
      // copy the numeric raw values of this row into the raw-data store
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
    // metadata_ will manage space of init_score
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
  }
  dataset->FinishLoad();
  // text data can be free after loaded feature values
  text_data->clear();
}

/*! \brief Extract local features from file */
1399
1400
void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser,
                                            const std::vector<data_size_t>& used_data_indices, Dataset* dataset) {
1401
  std::vector<double> init_score;
1402
  if (predict_fun_) {
1403
    init_score = std::vector<double>(static_cast<size_t>(dataset->num_data_) * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1404
1405
1406
1407
1408
1409
  }
  std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
    [this, &init_score, &parser, &dataset]
  (data_size_t start_idx, const std::vector<std::string>& lines) {
    std::vector<std::pair<int, double>> oneline_features;
    double tmp_label = 0.0f;
1410
    std::vector<float> feature_row(dataset->num_features_);
1411
    OMP_INIT_EX();
1412
    #pragma omp parallel for num_threads(OMP_NUM_THREADS()) schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
Guolin Ke's avatar
Guolin Ke committed
1413
    for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
1414
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1415
1416
1417
1418
1419
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1420
      if (!init_score.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1421
1422
        std::vector<double> oneline_init_score(num_class_);
        predict_fun_(oneline_features, oneline_init_score.data());
1423
        for (int k = 0; k < num_class_; ++k) {
1424
          init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1425
1426
1427
        }
      }
      // set label
1428
      dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1429
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1430
1431
      // push data
      for (auto& inner_data : oneline_features) {
1432
1433
1434
        if (inner_data.first >= dataset->num_total_features_) {
          continue;
        }
Guolin Ke's avatar
Guolin Ke committed
1435
1436
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1437
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1438
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1439
1440
1441
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
1442
          if (dataset->has_raw()) {
1443
            feature_row[feature_idx] = static_cast<float>(inner_data.second);
1444
          }
Guolin Ke's avatar
Guolin Ke committed
1445
1446
        } else {
          if (inner_data.first == weight_idx_) {
1447
            dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1448
1449
1450
1451
1452
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1453
1454
1455
1456
1457
1458
1459
1460
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1461
      dataset->FinishOneRow(tid, i, is_feature_added);
1462
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1463
    }
1464
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1465
  };
1466
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
1467
  if (!used_data_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1468
1469
1470
1471
1472
1473
1474
1475
    // only need part of data
    text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
  } else {
    // need full data
    text_reader.ReadAllAndProcessParallel(process_fun);
  }

  // metadata_ will manage space of init_score
Guolin Ke's avatar
Guolin Ke committed
1476
  if (!init_score.empty()) {
1477
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1478
  }
Guolin Ke's avatar
Guolin Ke committed
1479
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1480
1481
1482
}

/*! \brief Check can load from binary file */
1483
std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
1484
1485
1486
  std::string bin_filename(filename);
  bin_filename.append(".bin");

1487
  auto reader = VirtualFileReader::Make(bin_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
1488

1489
  if (!reader->Init()) {
1490
    bin_filename = std::string(filename);
1491
1492
    reader = VirtualFileReader::Make(bin_filename.c_str());
    if (!reader->Init()) {
1493
      Log::Fatal("Cannot open data file %s", bin_filename.c_str());
1494
    }
1495
  }
1496
1497
1498
1499
1500

  size_t buffer_size = 256;
  auto buffer = std::vector<char>(buffer_size);
  // read size of token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
1501
  size_t read_cnt = reader->Read(buffer.data(), size_of_token);
1502
1503
  if (read_cnt == size_of_token
      && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
1504
    return bin_filename;
Guolin Ke's avatar
Guolin Ke committed
1505
  } else {
1506
    return std::string();
Guolin Ke's avatar
Guolin Ke committed
1507
1508
1509
  }
}

1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
                                                              const std::unordered_set<int>& categorical_features) {
  std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
  if (forced_bins_path != "") {
    std::ifstream forced_bins_stream(forced_bins_path.c_str());
    if (forced_bins_stream.fail()) {
      Log::Warning("Could not open %s. Will ignore.", forced_bins_path.c_str());
    } else {
      std::stringstream buffer;
      buffer << forced_bins_stream.rdbuf();
      std::string err;
Guolin Ke's avatar
Guolin Ke committed
1521
      Json forced_bins_json = Json::parse(buffer.str(), &err);
1522
1523
1524
1525
      CHECK(forced_bins_json.is_array());
      std::vector<Json> forced_bins_arr = forced_bins_json.array_items();
      for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
        int feature_num = forced_bins_arr[i]["feature"].int_value();
Nikita Titov's avatar
Nikita Titov committed
1526
        CHECK_LT(feature_num, num_total_features);
1527
        if (categorical_features.count(feature_num)) {
1528
          Log::Warning("Feature %d is categorical. Will ignore forced bins for this feature.", feature_num);
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
        } else {
          std::vector<Json> bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items();
          for (size_t j = 0; j < bounds_arr.size(); ++j) {
            forced_bins[feature_num].push_back(bounds_arr[j].number_value());
          }
        }
      }
      // remove duplicates
      for (int i = 0; i < num_total_features; ++i) {
        auto new_end = std::unique(forced_bins[i].begin(), forced_bins[i].end());
        forced_bins[i].erase(new_end, forced_bins[i].end());
      }
    }
  }
  return forced_bins;
}

1546
1547
1548
1549
1550
1551
1552
void DatasetLoader::CheckCategoricalFeatureNumBin(
  const std::vector<std::unique_ptr<BinMapper>>& bin_mappers,
  const int max_bin, const std::vector<int>& max_bin_by_feature) const {
  bool need_warning = false;
  if (bin_mappers.size() < 1024) {
    for (size_t i = 0; i < bin_mappers.size(); ++i) {
      const int max_bin_for_this_feature = max_bin_by_feature.empty() ? max_bin : max_bin_by_feature[i];
1553
      if (bin_mappers[i] != nullptr && bin_mappers[i]->bin_type() == BinType::CategoricalBin && bin_mappers[i]->num_bin() > max_bin_for_this_feature) {
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
        need_warning = true;
        break;
      }
    }
  } else {
    const int num_threads = OMP_NUM_THREADS();
    std::vector<bool> thread_need_warning(num_threads, false);
    Threading::For<size_t>(0, bin_mappers.size(), 1,
      [&bin_mappers, &thread_need_warning, &max_bin_by_feature, max_bin] (int thread_index, size_t start, size_t end) {
        for (size_t i = start; i < end; ++i) {
          thread_need_warning[thread_index] = false;
          const int max_bin_for_this_feature = max_bin_by_feature.empty() ? max_bin : max_bin_by_feature[i];
1566
          if (bin_mappers[i] != nullptr && bin_mappers[i]->bin_type() == BinType::CategoricalBin && bin_mappers[i]->num_bin() > max_bin_for_this_feature) {
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
            thread_need_warning[thread_index] = true;
            break;
          }
        }
      });
    for (int thread_index = 0; thread_index < num_threads; ++thread_index) {
      if (thread_need_warning[thread_index]) {
        need_warning = true;
        break;
      }
    }
  }

  if (need_warning) {
    Log::Warning("Categorical features with more bins than the configured maximum bin number found.");
    Log::Warning("For categorical features, max_bin and max_bin_by_feature may be ignored with a large number of categories.");
  }
}

1586
}  // namespace LightGBM