dataset_loader.cpp 58.6 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/dataset_loader.h>

Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/network.h>
8
#include <LightGBM/utils/array_args.h>
9
#include <LightGBM/utils/json11.h>
10
11
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
12

13
14
#include <fstream>

Guolin Ke's avatar
Guolin Ke committed
15
16
namespace LightGBM {

17
18
using json11::Json;

Guolin Ke's avatar
Guolin Ke committed
19
20
/*!
 * \brief Build a loader and resolve the column layout of the input data.
 * \param io_config Configuration used to initialize config_.
 * \param predict_fun Used to initialize predict_fun_.
 * \param num_class Used to initialize num_class_.
 * \param filename Data file whose header SetHeader() inspects; may be nullptr.
 */
DatasetLoader::DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename)
  : config_(io_config),
    random_(config_.data_random_seed),
    predict_fun_(predict_fun),
    num_class_(num_class) {
  // Default column roles: label in column 0, no weight column, no group column.
  // SetHeader() may overwrite any of these from the config / file header.
  label_idx_ = 0;
  weight_idx_ = NO_SPECIFIC;
  group_idx_ = NO_SPECIFIC;
  SetHeader(filename);

  // Raw feature values are kept only when linear trees are requested.
  store_raw_ = io_config.linear_tree;
}

// No explicit cleanup is performed by the loader itself.
DatasetLoader::~DatasetLoader() = default;

Guolin Ke's avatar
Guolin Ke committed
34
void DatasetLoader::SetHeader(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
35
  std::unordered_map<std::string, int> name2idx;
Guolin Ke's avatar
Guolin Ke committed
36
37
  std::string name_prefix("name:");
  if (filename != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
38
    TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
39

Guolin Ke's avatar
Guolin Ke committed
40
    // get column names
Guolin Ke's avatar
Guolin Ke committed
41
    if (config_.header) {
Guolin Ke's avatar
Guolin Ke committed
42
      std::string first_line = text_reader.first_line();
43
      feature_names_ = Common::Split(first_line.c_str(), "\t,");
Guolin Ke's avatar
Guolin Ke committed
44
45
    }

Guolin Ke's avatar
Guolin Ke committed
46
    // load label idx first
Guolin Ke's avatar
Guolin Ke committed
47
48
49
    if (config_.label_column.size() > 0) {
      if (Common::StartsWith(config_.label_column, name_prefix)) {
        std::string name = config_.label_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
50
51
52
53
54
55
56
57
58
59
        label_idx_ = -1;
        for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
          if (name == feature_names_[i]) {
            label_idx_ = i;
            break;
          }
        }
        if (label_idx_ >= 0) {
          Log::Info("Using column %s as label", name.c_str());
        } else {
60
61
          Log::Fatal("Could not find label column %s in data file \n"
                     "or data file doesn't contain header", name.c_str());
Guolin Ke's avatar
Guolin Ke committed
62
        }
Guolin Ke's avatar
Guolin Ke committed
63
      } else {
Guolin Ke's avatar
Guolin Ke committed
64
        if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
65
66
67
          Log::Fatal("label_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
68
69
        }
        Log::Info("Using column number %d as label", label_idx_);
Guolin Ke's avatar
Guolin Ke committed
70
71
      }
    }
Guolin Ke's avatar
Guolin Ke committed
72

Guolin Ke's avatar
Guolin Ke committed
73
    if (!feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
74
75
76
77
      // erase label column name
      feature_names_.erase(feature_names_.begin() + label_idx_);
      for (size_t i = 0; i < feature_names_.size(); ++i) {
        name2idx[feature_names_[i]] = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
78
      }
Guolin Ke's avatar
Guolin Ke committed
79
80
81
    }

    // load ignore columns
Guolin Ke's avatar
Guolin Ke committed
82
83
84
    if (config_.ignore_column.size() > 0) {
      if (Common::StartsWith(config_.ignore_column, name_prefix)) {
        std::string names = config_.ignore_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
85
86
87
88
89
90
91
92
93
        for (auto name : Common::Split(names.c_str(), ',')) {
          if (name2idx.count(name) > 0) {
            int tmp = name2idx[name];
            ignore_features_.emplace(tmp);
          } else {
            Log::Fatal("Could not find ignore column %s in data file", name.c_str());
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
94
        for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
Guolin Ke's avatar
Guolin Ke committed
95
96
          int tmp = 0;
          if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
97
98
99
            Log::Fatal("ignore_column is not a number,\n"
                       "if you want to use a column name,\n"
                       "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
100
101
          }
          ignore_features_.emplace(tmp);
Guolin Ke's avatar
Guolin Ke committed
102
103
104
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
105
    // load weight idx
Guolin Ke's avatar
Guolin Ke committed
106
107
108
    if (config_.weight_column.size() > 0) {
      if (Common::StartsWith(config_.weight_column, name_prefix)) {
        std::string name = config_.weight_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
109
110
111
112
113
114
        if (name2idx.count(name) > 0) {
          weight_idx_ = name2idx[name];
          Log::Info("Using column %s as weight", name.c_str());
        } else {
          Log::Fatal("Could not find weight column %s in data file", name.c_str());
        }
Guolin Ke's avatar
Guolin Ke committed
115
      } else {
Guolin Ke's avatar
Guolin Ke committed
116
        if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
117
118
119
          Log::Fatal("weight_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
120
121
        }
        Log::Info("Using column number %d as weight", weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
122
      }
Guolin Ke's avatar
Guolin Ke committed
123
      ignore_features_.emplace(weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
124
    }
Guolin Ke's avatar
Guolin Ke committed
125
    // load group idx
Guolin Ke's avatar
Guolin Ke committed
126
127
128
    if (config_.group_column.size() > 0) {
      if (Common::StartsWith(config_.group_column, name_prefix)) {
        std::string name = config_.group_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
129
130
131
132
133
134
135
        if (name2idx.count(name) > 0) {
          group_idx_ = name2idx[name];
          Log::Info("Using column %s as group/query id", name.c_str());
        } else {
          Log::Fatal("Could not find group/query column %s in data file", name.c_str());
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
136
        if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
137
138
139
          Log::Fatal("group_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
140
141
142
143
        }
        Log::Info("Using column number %d as group/query id", group_idx_);
      }
      ignore_features_.emplace(group_idx_);
Guolin Ke's avatar
Guolin Ke committed
144
145
    }
  }
Guolin Ke's avatar
Guolin Ke committed
146
147
148
  if (config_.categorical_feature.size() > 0) {
    if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
      std::string names = config_.categorical_feature.substr(name_prefix.size());
149
150
151
152
153
      for (auto name : Common::Split(names.c_str(), ',')) {
        if (name2idx.count(name) > 0) {
          int tmp = name2idx[name];
          categorical_features_.emplace(tmp);
        } else {
Guolin Ke's avatar
Guolin Ke committed
154
          Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
155
156
157
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
158
      for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
159
160
        int tmp = 0;
        if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Guolin Ke's avatar
Guolin Ke committed
161
          Log::Fatal("categorical_feature is not a number,\n"
162
163
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
164
165
166
167
168
        }
        categorical_features_.emplace(tmp);
      }
    }
  }
Guolin Ke's avatar
Guolin Ke committed
169
170
}

171
172
173
174
175
176
177
178
179
180
/*!
 * \brief Warn when the sample used to build feature bins looks too small.
 * \param sample_cnt Number of sampled lines.
 * \param num_data Number of lines the sample was drawn from.
 *
 * The warning is emitted only when the sample is both below 20% of the data
 * and below 100000 lines; a large absolute sample is accepted at any ratio.
 */
void CheckSampleSize(size_t sample_cnt, size_t num_data) {
  const double sampled_fraction = static_cast<double>(sample_cnt) / num_data;
  const bool is_small_fraction = sampled_fraction < 0.2f;
  const bool is_small_count = sample_cnt < 100000;
  if (!is_small_fraction || !is_small_count) {
    return;
  }
  Log::Warning(
      "Using too small ``bin_construct_sample_cnt`` may encounter "
      "unexpected "
      "errors and poor accuracy.");
}

181
/*!
 * \brief Load a training dataset from a text file or its LightGBM binary form.
 * \param filename Path of the data file.
 * \param rank Rank of this machine; forwarded to the loading/sampling helpers.
 * \param num_machines Number of machines; forwarded likewise.
 * \return Newly allocated Dataset; ownership passes to the caller.
 *
 * When CheckCanLoadFromBin() yields a binary file name, the dataset is read
 * from that file. Otherwise the text file is parsed either in memory or, with
 * config_.two_round, by sampling the file and then making a second pass. In
 * both cases metadata is checked/partitioned and CheckDataset() is run.
 */
Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_machines) {
  // don't support query id in data file when training in parallel.
  // group_idx_ stays NO_SPECIFIC unless a group column was configured, and
  // index 0 is a legitimate group column, so compare against NO_SPECIFIC
  // rather than testing "> 0" (which let a group column at index 0 through).
  if (num_machines > 1 && !config_.pre_partition) {
    if (group_idx_ != NO_SPECIFIC) {
      Log::Fatal("Using a query id without pre-partitioning the data file is not supported for parallel training.\n"
                 "Please use an additional query file or pre-partition the data");
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  if (store_raw_) {
    dataset->SetHasRaw(true);
  }
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto bin_filename = CheckCanLoadFromBin(filename);
  bool is_load_from_binary = false;
  if (bin_filename.size() == 0) {
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data to memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // sample data
      auto sample_data = SampleTextDataFromMemory(text_data);
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      // sample data from file
      auto sample_data = SampleTextDataFromFile(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      // a non-empty index list means only a subset of rows belongs to this machine
      if (used_data_indices.size() > 0) {
        dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
      } else {
        dataset->num_data_ = num_global_data;
      }
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      Log::Debug("Making second pass...");
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    is_load_from_binary = true;
    Log::Info("Load from binary file %s", bin_filename.c_str());
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
  }
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  // need to check training data
  CheckDataset(dataset.get(), is_load_from_binary);

  return dataset.release();
}



260
/*!
 * \brief Load a dataset (e.g. validation data) that reuses train_data's
 *        layout via Dataset::CreateValid().
 * \param filename Path of the data file (text, or a file for which
 *        CheckCanLoadFromBin() finds a binary counterpart).
 * \param train_data Dataset passed to CreateValid() for the text path.
 * \return Newly allocated Dataset; ownership passes to the caller.
 *
 * All loading is done as rank 0 of a single machine. Unlike LoadFromFile(),
 * no CheckDataset() call is made; only the metadata is checked.
 */
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  if (store_raw_) {
    dataset->SetHasRaw(true);
  }
  const auto bin_filename = CheckCanLoadFromBin(filename);
  if (!bin_filename.empty()) {
    // binary input: the stored dataset replaces the freshly created one
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), 0, 1, &num_global_data, &used_data_indices));
  } else {
    std::unique_ptr<Parser> text_parser(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (text_parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (config_.two_round) {
      // Two-round mode: only count rows now, then stream the file while extracting.
      TextReader<data_size_t> line_reader(filename, config_.header);
      dataset->num_data_ = static_cast<data_size_t>(line_reader.CountLine());
      num_global_data = dataset->num_data_;
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      dataset->CreateValid(train_data);
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      ExtractFeaturesFromFile(filename, text_parser.get(), used_data_indices, dataset.get());
    } else {
      // Single-round mode: pull every line into memory, then extract from it.
      auto lines = LoadTextDataToMemory(filename, dataset->metadata_, 0, 1, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(lines.size());
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      dataset->CreateValid(train_data);
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      ExtractFeaturesFromMemory(&lines, text_parser.get(), dataset.get());
      lines.clear();
    }
  }
  // Only the metadata is validated for this kind of dataset.
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  return dataset.release();
}

313
314
315
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
                                        int rank, int num_machines, int* num_global_data,
                                        std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
316
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
317
  auto reader = VirtualFileReader::Make(bin_filename);
Guolin Ke's avatar
Guolin Ke committed
318
  dataset->data_filename_ = data_filename;
319
  if (!reader->Init()) {
Guolin Ke's avatar
Guolin Ke committed
320
321
322
323
324
    Log::Fatal("Could not read binary data from %s", bin_filename);
  }

  // buffer to read binary file
  size_t buffer_size = 16 * 1024 * 1024;
Guolin Ke's avatar
Guolin Ke committed
325
  auto buffer = std::vector<char>(buffer_size);
326

327
328
  // check token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
329
330
331
332
  size_t read_cnt = reader->Read(
      buffer.data(),
      VirtualFileWriter::AlignedSize(sizeof(char) * size_of_token));
  if (read_cnt < sizeof(char) * size_of_token) {
333
334
335
    Log::Fatal("Binary file error: token has the wrong size");
  }
  if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
336
    Log::Fatal("Input file is not LightGBM binary file");
337
  }
Guolin Ke's avatar
Guolin Ke committed
338
339

  // read size of header
340
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
341

342
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
343
344
345
    Log::Fatal("Binary file error: header has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
346
  size_t size_of_head = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
347
348
349
350

  // re-allocmate space if not enough
  if (size_of_head > buffer_size) {
    buffer_size = size_of_head;
Guolin Ke's avatar
Guolin Ke committed
351
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
352
353
  }
  // read header
354
  read_cnt = reader->Read(buffer.data(), size_of_head);
Guolin Ke's avatar
Guolin Ke committed
355
356
357
358
359

  if (read_cnt != size_of_head) {
    Log::Fatal("Binary file error: header is incorrect");
  }
  // get header
Guolin Ke's avatar
Guolin Ke committed
360
  const char* mem_ptr = buffer.data();
Guolin Ke's avatar
Guolin Ke committed
361
  dataset->num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
362
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_data_));
Guolin Ke's avatar
Guolin Ke committed
363
  dataset->num_features_ = *(reinterpret_cast<const int*>(mem_ptr));
364
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_features_));
Guolin Ke's avatar
Guolin Ke committed
365
  dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
366
367
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(dataset->num_total_features_));
Guolin Ke's avatar
Guolin Ke committed
368
  dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
369
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->label_idx_));
370
  dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
371
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->max_bin_));
372
  dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
373
374
  mem_ptr += VirtualFileWriter::AlignedSize(
      sizeof(dataset->bin_construct_sample_cnt_));
375
  dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
376
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->min_data_in_bin_));
377
  dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
378
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->use_missing_));
379
  dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
380
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->zero_as_missing_));
381
382
  dataset->has_raw_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->has_raw_));
Guolin Ke's avatar
Guolin Ke committed
383
384
  const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
  dataset->used_feature_map_.clear();
Guolin Ke's avatar
Guolin Ke committed
385
  for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
386
387
    dataset->used_feature_map_.push_back(tmp_feature_map[i]);
  }
388
389
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int) *
                                            dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
390
391
  // num_groups
  dataset->num_groups_ = *(reinterpret_cast<const int*>(mem_ptr));
392
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
393
394
395
396
397
398
  // real_feature_idx_
  const int* tmp_ptr_real_feature_idx_ = reinterpret_cast<const int*>(mem_ptr);
  dataset->real_feature_idx_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->real_feature_idx_.push_back(tmp_ptr_real_feature_idx_[i]);
  }
399
400
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
401
402
403
404
405
406
  // feature2group
  const int* tmp_ptr_feature2group = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2group_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2group_.push_back(tmp_ptr_feature2group[i]);
  }
407
408
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
409
410
411
412
413
414
  // feature2subfeature
  const int* tmp_ptr_feature2subfeature = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2subfeature_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2subfeature_.push_back(tmp_ptr_feature2subfeature[i]);
  }
415
416
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
417
418
419
420
421
422
423
424
425
426
427
  // group_bin_boundaries
  const uint64_t* tmp_ptr_group_bin_boundaries = reinterpret_cast<const uint64_t*>(mem_ptr);
  dataset->group_bin_boundaries_.clear();
  for (int i = 0; i < dataset->num_groups_ + 1; ++i) {
    dataset->group_bin_boundaries_.push_back(tmp_ptr_group_bin_boundaries[i]);
  }
  mem_ptr += sizeof(uint64_t) * (dataset->num_groups_ + 1);

  // group_feature_start_
  const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_start_.clear();
428
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
429
430
    dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
  }
431
432
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
433
434
435
436
437
438
439

  // group_feature_cnt_
  const int* tmp_ptr_group_feature_cnt = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_cnt_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_cnt_.push_back(tmp_ptr_group_feature_cnt[i]);
  }
440
441
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
442

Belinda Trotta's avatar
Belinda Trotta committed
443
  if (!config_.max_bin_by_feature.empty()) {
444
445
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
446
447
448
449
450
451
452
453
454
    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
  } else {
    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
    dataset->max_bin_by_feature_.clear();
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
    }
  }
455
456
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int32_t) *
                                            (dataset->num_total_features_));
Belinda Trotta's avatar
Belinda Trotta committed
457
458
459
460
  if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
    dataset->max_bin_by_feature_.clear();
  }

Guolin Ke's avatar
Guolin Ke committed
461
  // get feature names
Guolin Ke's avatar
Guolin Ke committed
462
  dataset->feature_names_.clear();
Guolin Ke's avatar
Guolin Ke committed
463
464
465
  // write feature names
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int str_len = *(reinterpret_cast<const int*>(mem_ptr));
466
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
Guolin Ke's avatar
Guolin Ke committed
467
    std::stringstream str_buf;
468
    auto tmp_arr = reinterpret_cast<const char*>(mem_ptr);
Guolin Ke's avatar
Guolin Ke committed
469
    for (int j = 0; j < str_len; ++j) {
470
      char tmp_char = tmp_arr[j];
Guolin Ke's avatar
Guolin Ke committed
471
472
      str_buf << tmp_char;
    }
473
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(char) * str_len);
Guolin Ke's avatar
Guolin Ke committed
474
    dataset->feature_names_.emplace_back(str_buf.str());
Guolin Ke's avatar
Guolin Ke committed
475
  }
476
477
478
479
  // get forced_bin_bounds_
  dataset->forced_bin_bounds_ = std::vector<std::vector<double>>(dataset->num_total_features_, std::vector<double>());
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int num_bounds = *(reinterpret_cast<const int*>(mem_ptr));
480
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
481
    dataset->forced_bin_bounds_[i] = std::vector<double>();
482
483
    const double* tmp_ptr_forced_bounds =
        reinterpret_cast<const double*>(mem_ptr);
484
485
486
487
488
489
    for (int j = 0; j < num_bounds; ++j) {
      double bound = tmp_ptr_forced_bounds[j];
      dataset->forced_bin_bounds_[i].push_back(bound);
    }
    mem_ptr += num_bounds * sizeof(double);
  }
Guolin Ke's avatar
Guolin Ke committed
490
491

  // read size of meta data
492
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
493

494
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
495
496
497
    Log::Fatal("Binary file error: meta data has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
498
  size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
499
500
501
502

  // re-allocate space if not enough
  if (size_of_metadata > buffer_size) {
    buffer_size = size_of_metadata;
Guolin Ke's avatar
Guolin Ke committed
503
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
504
505
  }
  //  read meta data
506
  read_cnt = reader->Read(buffer.data(), size_of_metadata);
Guolin Ke's avatar
Guolin Ke committed
507
508
509
510
511

  if (read_cnt != size_of_metadata) {
    Log::Fatal("Binary file error: meta data is incorrect");
  }
  // load meta data
Guolin Ke's avatar
Guolin Ke committed
512
  dataset->metadata_.LoadFromMemory(buffer.data());
Guolin Ke's avatar
Guolin Ke committed
513

514
515
  *num_global_data = dataset->num_data_;
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
516
  // sample local used data if need to partition
Guolin Ke's avatar
Guolin Ke committed
517
  if (num_machines > 1 && !config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
518
519
520
521
    const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
522
        if (random_.NextShort(0, num_machines) == rank) {
523
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
524
525
526
527
528
529
530
531
532
        }
      }
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = dataset->metadata_.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (qid >= num_queries) {
533
534
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
535
536
537
538
        }
        if (i >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
539
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
540
541
542
543
544
            is_query_used = true;
          }
          ++qid;
        }
        if (is_query_used) {
545
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
546
547
548
        }
      }
    }
549
    dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
Guolin Ke's avatar
Guolin Ke committed
550
  }
551
  dataset->metadata_.PartitionLabel(*used_data_indices);
Guolin Ke's avatar
Guolin Ke committed
552
  // read feature data
Guolin Ke's avatar
Guolin Ke committed
553
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
554
    // read feature size
555
556
    read_cnt = reader->Read(buffer.data(), sizeof(size_t));
    if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
557
558
      Log::Fatal("Binary file error: feature %d has the wrong size", i);
    }
Guolin Ke's avatar
Guolin Ke committed
559
    size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
560
561
562
    // re-allocate space if not enough
    if (size_of_feature > buffer_size) {
      buffer_size = size_of_feature;
Guolin Ke's avatar
Guolin Ke committed
563
      buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
564
565
    }

566
    read_cnt = reader->Read(buffer.data(), size_of_feature);
Guolin Ke's avatar
Guolin Ke committed
567
568
569
570

    if (read_cnt != size_of_feature) {
      Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
    }
Guolin Ke's avatar
Guolin Ke committed
571
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
572
573
      new FeatureGroup(buffer.data(),
                       *num_global_data,
574
                       *used_data_indices, i)));
Guolin Ke's avatar
Guolin Ke committed
575
  }
Guolin Ke's avatar
Guolin Ke committed
576
  dataset->feature_groups_.shrink_to_fit();
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612

  // raw data
  dataset->numeric_feature_map_ = std::vector<int>(dataset->num_features_, false);
  dataset->num_numeric_features_ = 0;
  for (int i = 0; i < dataset->num_features_; ++i) {
    if (dataset->FeatureBinMapper(i)->bin_type() == BinType::CategoricalBin) {
      dataset->numeric_feature_map_[i] = -1;
    } else {
      dataset->numeric_feature_map_[i] = dataset->num_numeric_features_;
      ++dataset->num_numeric_features_;
    }
  }
  if (dataset->has_raw()) {
    dataset->ResizeRaw(dataset->num_data());
      size_t row_size = dataset->num_numeric_features_ * sizeof(float);
      if (row_size > buffer_size) {
        buffer_size = row_size;
        buffer.resize(buffer_size);
      }
    for (int i = 0; i < dataset->num_data(); ++i) {
      read_cnt = reader->Read(buffer.data(), row_size);
      if (read_cnt != row_size) {
        Log::Fatal("Binary file error: row %d of raw data is incorrect, read count: %d", i, read_cnt);
      }
      mem_ptr = buffer.data();
      const float* tmp_ptr_raw_row = reinterpret_cast<const float*>(mem_ptr);
      for (int j = 0; j < dataset->num_features(); ++j) {
        int feat_ind = dataset->numeric_feature_map_[j];
        if (feat_ind >= 0) {
          dataset->raw_data_[feat_ind][i] = tmp_ptr_raw_row[feat_ind];
        }
      }
      mem_ptr += row_size;
    }
  }

Guolin Ke's avatar
Guolin Ke committed
613
  dataset->is_finish_load_ = true;
Guolin Ke's avatar
Guolin Ke committed
614
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
615
616
}

617

618
619
620
/*!
 * \brief Build a Dataset (bin mappers and feature layout) from column-wise sampled values.
 *
 * \param sample_values     sample_values[j] holds num_per_col[j] sampled values of column j
 * \param sample_indices    per-column sample row indices, forwarded to Dataset::Construct
 * \param num_col           number of columns available locally
 * \param num_per_col       number of entries stored in sample_values[j] for each column j
 * \param total_sample_size number of sampled rows
 * \param num_data          number of rows of the full dataset
 * \return a newly allocated Dataset; the caller takes ownership
 *
 * With one machine every column is binned locally. With several machines each rank
 * bins a contiguous slice [start[rank], start[rank] + len[rank]) of the globally
 * synchronized feature range; the serialized bin mappers are then exchanged with
 * Network::Allgather so every rank ends up with all of them.
 */
Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
                                                int** sample_indices, int num_col, const int* num_per_col,
                                                size_t total_sample_size, data_size_t num_data) {
  CheckSampleSize(total_sample_size, static_cast<size_t>(num_data));
  int num_total_features = num_col;
  if (Network::num_machines() > 1) {
    // all machines must agree on the width of the feature space
    num_total_features = Network::GlobalSyncUpByMax(num_total_features);
  }
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_total_features);
  // fill feature_names_ if not header
  if (feature_names_.empty()) {
    for (int i = 0; i < num_col; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
  if (!config_.max_bin_by_feature.empty()) {
    CHECK_EQ(static_cast<size_t>(num_col), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
  }

  // get forced split; one entry per (global) column index
  std::string forced_bins_path = config_.forcedbins_filename;
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);

  // min_data_in_leaf rescaled from the full row count to the sample size
  const data_size_t filter_cnt = static_cast<data_size_t>(
    static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
  if (Network::num_machines() == 1) {
    // if only one machine, find bin locally
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
        bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
        if (!feat_is_unconstrained) {
            Log::Fatal("The output cannot be monotone with respect to categorical features");
        }
      }
      bin_mappers[i].reset(new BinMapper());
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                                config_.max_bin, config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter,
                                bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[i]);
      } else {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                                config_.max_bin_by_feature[i], config_.min_data_in_bin,
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
                                config_.zero_as_missing, forced_bin_bounds[i]);
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    // if have multi-machines, need to find bin distributed
    // different machines will find bin for different features
    int num_machines = Network::num_machines();
    int rank = Network::rank();
    // start and len will store the process feature indices for different machines
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
    int step = (num_total_features + num_machines - 1) / num_machines;
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
      len[i] = std::min(step, num_total_features - start[i]);
      start[i + 1] = start[i] + len[i];
    }
    len[num_machines - 1] = num_total_features - start[num_machines - 1];
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < len[rank]; ++i) {
      OMP_LOOP_EX_BEGIN();
      // bin_mappers is indexed locally (i) here, while every per-column input
      // (samples, max_bin_by_feature, forced bins) is indexed by the global column.
      const int feature_idx = start[rank] + i;
      if (ignore_features_.count(feature_idx) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(feature_idx)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
      if (num_col <= feature_idx) {
        // column is not present locally: keep the default-constructed mapper
        continue;
      }
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[feature_idx], num_per_col[feature_idx],
                                total_sample_size, config_.max_bin, config_.min_data_in_bin,
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[feature_idx]);
      } else {
        bin_mappers[i]->FindBin(sample_values[feature_idx], num_per_col[feature_idx],
                                total_sample_size, config_.max_bin_by_feature[feature_idx],
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
                                config_.zero_as_missing, forced_bin_bounds[feature_idx]);
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
    comm_size_t self_buf_size = 0;
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
    }
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
      // free
      bin_mappers[i].reset(nullptr);
    }
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
    }
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
    // gather global feature bin mappers
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
    // restore features bins from buffer
    for (int i = 0; i < num_total_features; ++i) {
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      bin_mappers[i].reset(new BinMapper());
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
  dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
  if (dataset->has_raw()) {
    dataset->ResizeRaw(num_data);
  }
  dataset->set_feature_names(feature_names_);
  return dataset.release();
}
Guolin Ke's avatar
Guolin Ke committed
773
774
775
776


// ---- private functions ----

777
void DatasetLoader::CheckDataset(const Dataset* dataset, bool is_load_from_binary) {
Guolin Ke's avatar
Guolin Ke committed
778
  if (dataset->num_data_ <= 0) {
Guolin Ke's avatar
Guolin Ke committed
779
    Log::Fatal("Data file %s is empty", dataset->data_filename_.c_str());
Guolin Ke's avatar
Guolin Ke committed
780
  }
781
782
  if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
    Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
783
               static_cast<int>(dataset->feature_names_.size()));
784
  }
Guolin Ke's avatar
Guolin Ke committed
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
  bool is_feature_order_by_group = true;
  int last_group = -1;
  int last_sub_feature = -1;
  // if features are ordered, not need to use hist_buf
  for (int i = 0; i < dataset->num_features_; ++i) {
    int group = dataset->feature2group_[i];
    int sub_feature = dataset->feature2subfeature_[i];
    if (group < last_group) {
      is_feature_order_by_group = false;
    } else if (group == last_group) {
      if (sub_feature <= last_sub_feature) {
        is_feature_order_by_group = false;
        break;
      }
    }
    last_group = group;
    last_sub_feature = sub_feature;
  }
  if (!is_feature_order_by_group) {
804
    Log::Fatal("Features in dataset should be ordered by group");
Guolin Ke's avatar
Guolin Ke committed
805
  }
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831

  if (is_load_from_binary) {
    if (dataset->max_bin_ != config_.max_bin) {
      Log::Fatal("Dataset max_bin %d != config %d", dataset->max_bin_, config_.max_bin);
    }
    if (dataset->min_data_in_bin_ != config_.min_data_in_bin) {
      Log::Fatal("Dataset min_data_in_bin %d != config %d", dataset->min_data_in_bin_, config_.min_data_in_bin);
    }
    if (dataset->use_missing_ != config_.use_missing) {
      Log::Fatal("Dataset use_missing %d != config %d", dataset->use_missing_, config_.use_missing);
    }
    if (dataset->zero_as_missing_ != config_.zero_as_missing) {
      Log::Fatal("Dataset zero_as_missing %d != config %d", dataset->zero_as_missing_, config_.zero_as_missing);
    }
    if (dataset->bin_construct_sample_cnt_ != config_.bin_construct_sample_cnt) {
      Log::Fatal("Dataset bin_construct_sample_cnt %d != config %d", dataset->bin_construct_sample_cnt_, config_.bin_construct_sample_cnt);
    }
    if ((dataset->max_bin_by_feature_.size() != config_.max_bin_by_feature.size()) ||
        !std::equal(dataset->max_bin_by_feature_.begin(), dataset->max_bin_by_feature_.end(),
            config_.max_bin_by_feature.begin())) {
      Log::Fatal("Dataset max_bin_by_feature does not match with config");
    }

    int label_idx = -1;
    if (Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx)) {
      if (dataset->label_idx_ != label_idx) {
832
        Log::Fatal("Dataset label_idx %d != config %d", dataset->label_idx_, label_idx);
833
834
835
836
837
      }
    } else {
      Log::Info("Recommend use integer for label index when loading data from binary for sanity check.");
    }
  }
Guolin Ke's avatar
Guolin Ke committed
838
839
840
}

std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
841
842
                                                             int rank, int num_machines, int* num_global_data,
                                                             std::vector<data_size_t>* used_data_indices) {
843
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
844
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
845
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
846
847
848
849
850
851
852
853
854
    // read all lines
    *num_global_data = text_reader.ReadAllLines();
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();

    if (query_boundaries == nullptr) {
      // if not contain query data, minimal sample unit is one record
      *num_global_data = text_reader.ReadAndFilterLines([this, rank, num_machines](data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
855
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
856
857
858
859
860
861
862
863
864
865
866
867
868
869
          return true;
        } else {
          return false;
        }
      }, used_data_indices);
    } else {
      // if contain query data, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.ReadAndFilterLines(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
870
871
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
872
873
874
875
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
876
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
877
878
879
880
881
882
883
884
885
886
887
888
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices);
    }
  }
  return std::move(text_reader.Lines());
}

std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
Guolin Ke's avatar
Guolin Ke committed
889
  int sample_cnt = config_.bin_construct_sample_cnt;
890
891
  if (static_cast<size_t>(sample_cnt) > data.size()) {
    sample_cnt = static_cast<int>(data.size());
892
  }
893
  auto sample_indices = random_.Sample(static_cast<int>(data.size()), sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
894
  std::vector<std::string> out(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
895
896
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    const size_t idx = sample_indices[i];
Guolin Ke's avatar
Guolin Ke committed
897
    out[i] = data[idx];
Guolin Ke's avatar
Guolin Ke committed
898
899
900
901
  }
  return out;
}

902
903
904
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata,
                                                               int rank, int num_machines, int* num_global_data,
                                                               std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
905
  const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
906
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
907
  std::vector<std::string> out_data;
Guolin Ke's avatar
Guolin Ke committed
908
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
909
    *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(&random_, sample_cnt, &out_data));
Guolin Ke's avatar
Guolin Ke committed
910
911
912
913
914
915
916
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      *num_global_data = text_reader.SampleAndFilterFromFile([this, rank, num_machines]
      (data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
917
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
918
919
920
921
          return true;
        } else {
          return false;
        }
Guolin Ke's avatar
Guolin Ke committed
922
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
923
924
925
926
927
928
929
930
931
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.SampleAndFilterFromFile(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
932
933
          Log::Fatal("Query id exceeds the range of the query file, "
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
934
935
936
937
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
938
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
939
940
941
942
943
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
Guolin Ke's avatar
Guolin Ke committed
944
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
945
946
947
948
949
    }
  }
  return out_data;
}

950
951
952
void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
                                                    const std::vector<std::string>& sample_data,
                                                    const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
953
  std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
954
  std::vector<std::vector<int>> sample_indices;
Guolin Ke's avatar
Guolin Ke committed
955
956
  std::vector<std::pair<int, double>> oneline_features;
  double label;
Guolin Ke's avatar
Guolin Ke committed
957
  for (int i = 0; i < static_cast<int>(sample_data.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
958
959
960
961
    oneline_features.clear();
    // parse features
    parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
    for (std::pair<int, double>& inner_data : oneline_features) {
962
      if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
963
964
        sample_values.resize(inner_data.first + 1);
        sample_indices.resize(inner_data.first + 1);
965
      }
Guolin Ke's avatar
Guolin Ke committed
966
      if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
967
968
        sample_values[inner_data.first].emplace_back(inner_data.second);
        sample_indices[inner_data.first].emplace_back(i);
Guolin Ke's avatar
Guolin Ke committed
969
970
971
972
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
973
  dataset->feature_groups_.clear();
974
975
976
977
978
  dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
  if (num_machines > 1) {
    dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_);
  }
  if (!feature_names_.empty()) {
979
    CHECK_EQ(dataset->num_total_features_, static_cast<int>(feature_names_.size()));
980
  }
Guolin Ke's avatar
Guolin Ke committed
981

Belinda Trotta's avatar
Belinda Trotta committed
982
  if (!config_.max_bin_by_feature.empty()) {
983
984
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
985
986
  }

987
988
  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
989
990
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path,
                                                                                    dataset->num_total_features_,
991
992
                                                                                    categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
993
994
995
996
997
998
  // check the range of label_idx, weight_idx and group_idx
  CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
  CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
  CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);

  // fill feature_names_ if not header
Guolin Ke's avatar
Guolin Ke committed
999
  if (feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1000
1001
1002
1003
1004
1005
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
1006
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
1007
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
1008
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
1009
    static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
Guolin Ke's avatar
Guolin Ke committed
1010
1011
1012
  // start find bins
  if (num_machines == 1) {
    // if only one machine, find bin locally
1013
    OMP_INIT_EX();
1014
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
1015
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
1016
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1017
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1018
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1019
1020
        continue;
      }
1021
1022
1023
1024
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
Guolin Ke's avatar
Guolin Ke committed
1025
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
1026
1027
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
1028
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
1029
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
1030
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1031
1032
      } else {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
1033
                                sample_data.size(), config_.max_bin_by_feature[i],
1034
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
1035
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1036
      }
1037
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1038
    }
1039
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1040
1041
  } else {
    // start and len will store the process feature indices for different machines
1042
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
Guolin Ke's avatar
Guolin Ke committed
1043
1044
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
1045
    int step = (dataset->num_total_features_ + num_machines - 1) / num_machines;
Guolin Ke's avatar
Guolin Ke committed
1046
1047
1048
1049
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
1050
      len[i] = std::min(step, dataset->num_total_features_ - start[i]);
Guolin Ke's avatar
Guolin Ke committed
1051
1052
      start[i + 1] = start[i] + len[i];
    }
1053
    len[num_machines - 1] = dataset->num_total_features_ - start[num_machines - 1];
1054
    OMP_INIT_EX();
1055
    #pragma omp parallel for schedule(guided)
1056
    for (int i = 0; i < len[rank]; ++i) {
1057
      OMP_LOOP_EX_BEGIN();
1058
1059
1060
1061
1062
1063
1064
1065
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
Nikita Titov's avatar
Nikita Titov committed
1066
      if (static_cast<int>(sample_values.size()) <= start[rank] + i) {
1067
1068
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
1069
      if (config_.max_bin_by_feature.empty()) {
1070
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
1071
                                static_cast<int>(sample_values[start[rank] + i].size()),
1072
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
1073
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
1074
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1075
      } else {
1076
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
1077
                                static_cast<int>(sample_values[start[rank] + i].size()),
1078
                                sample_data.size(), config_.max_bin_by_feature[i],
1079
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type,
1080
                                config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1081
      }
1082
      OMP_LOOP_EX_END();
1083
    }
1084
    OMP_THROW_EX();
1085
    comm_size_t self_buf_size = 0;
Guolin Ke's avatar
Guolin Ke committed
1086
    for (int i = 0; i < len[rank]; ++i) {
1087
1088
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
Guolin Ke's avatar
Guolin Ke committed
1089
      }
1090
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
1091
    }
1092
1093
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1094
    for (int i = 0; i < len[rank]; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1095
1096
1097
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
1098
1099
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
1100
1101
      // free
      bin_mappers[i].reset(nullptr);
Guolin Ke's avatar
Guolin Ke committed
1102
    }
1103
1104
1105
1106
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
Guolin Ke's avatar
Guolin Ke committed
1107
    }
1108
1109
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
Guolin Ke's avatar
Guolin Ke committed
1110
    // gather global feature bin mappers
1111
1112
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1113
    // restore features bins from buffer
1114
    for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1115
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1116
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1117
1118
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
1119
      bin_mappers[i].reset(new BinMapper());
1120
1121
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
Guolin Ke's avatar
Guolin Ke committed
1122
1123
    }
  }
1124
  dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Guolin Ke's avatar
Guolin Ke committed
1125
                     Common::Vector2Ptr<double>(&sample_values).data(),
1126
                     Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
1127
  if (dataset->has_raw()) {
1128
    dataset->ResizeRaw(static_cast<int>(sample_data.size()));
1129
  }
Guolin Ke's avatar
Guolin Ke committed
1130
1131
1132
}

/*! \brief Extract local features from memory */
Guolin Ke's avatar
Guolin Ke committed
1133
void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
1134
1135
  std::vector<std::pair<int, double>> oneline_features;
  double tmp_label = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
1136
  auto& ref_text_data = *text_data;
1137
  std::vector<float> feature_row(dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
1138
  if (predict_fun_ == nullptr) {
1139
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1140
    // if doesn't need to prediction with initial model
1141
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
Guolin Ke's avatar
Guolin Ke committed
1142
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1143
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1144
1145
1146
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1147
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1148
      // set label
1149
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1150
      // free processed line:
Guolin Ke's avatar
Guolin Ke committed
1151
      ref_text_data[i].clear();
Guolin Ke's avatar
Guolin Ke committed
1152
1153
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
1154
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1155
1156
      // push data
      for (auto& inner_data : oneline_features) {
1157
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1158
1159
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1160
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1161
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1162
1163
1164
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
1165
          if (dataset->has_raw()) {
1166
            feature_row[feature_idx] = static_cast<float>(inner_data.second);
1167
          }
Guolin Ke's avatar
Guolin Ke committed
1168
1169
        } else {
          if (inner_data.first == weight_idx_) {
1170
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1171
1172
1173
1174
1175
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1176
1177
1178
1179
1180
1181
1182
1183
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1184
      dataset->FinishOneRow(tid, i, is_feature_added);
1185
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1186
    }
1187
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1188
  } else {
1189
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1190
    // if need to prediction with initial model
1191
    std::vector<double> init_score(dataset->num_data_ * num_class_);
1192
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
Guolin Ke's avatar
Guolin Ke committed
1193
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1194
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1195
1196
1197
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1198
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1199
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1200
1201
      std::vector<double> oneline_init_score(num_class_);
      predict_fun_(oneline_features, oneline_init_score.data());
1202
      for (int k = 0; k < num_class_; ++k) {
1203
        init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1204
1205
      }
      // set label
1206
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1207
      // free processed line:
1208
      ref_text_data[i].clear();
Guolin Ke's avatar
Guolin Ke committed
1209
1210
1211
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
Guolin Ke's avatar
Guolin Ke committed
1212
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1213
      for (auto& inner_data : oneline_features) {
1214
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1215
1216
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1217
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1218
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1219
1220
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
1221
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
1222
          if (dataset->has_raw()) {
1223
            feature_row[feature_idx] = static_cast<float>(inner_data.second);
1224
          }
Guolin Ke's avatar
Guolin Ke committed
1225
1226
        } else {
          if (inner_data.first == weight_idx_) {
1227
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1228
1229
1230
1231
1232
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1233
      dataset->FinishOneRow(tid, i, is_feature_added);
1234
1235
1236
1237
1238
1239
1240
1241
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
1242
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1243
    }
1244
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1245
    // metadata_ will manage space of init_score
1246
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1247
  }
Guolin Ke's avatar
Guolin Ke committed
1248
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1249
  // text data can be free after loaded feature values
Guolin Ke's avatar
Guolin Ke committed
1250
  text_data->clear();
Guolin Ke's avatar
Guolin Ke committed
1251
1252
1253
}

/*! \brief Extract local features from file */
1254
1255
void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser,
                                            const std::vector<data_size_t>& used_data_indices, Dataset* dataset) {
1256
  std::vector<double> init_score;
Guolin Ke's avatar
Guolin Ke committed
1257
  if (predict_fun_ != nullptr) {
1258
    init_score = std::vector<double>(dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1259
1260
1261
1262
1263
1264
  }
  std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
    [this, &init_score, &parser, &dataset]
  (data_size_t start_idx, const std::vector<std::string>& lines) {
    std::vector<std::pair<int, double>> oneline_features;
    double tmp_label = 0.0f;
1265
    std::vector<float> feature_row(dataset->num_features_);
1266
    OMP_INIT_EX();
1267
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
Guolin Ke's avatar
Guolin Ke committed
1268
    for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
1269
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1270
1271
1272
1273
1274
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1275
      if (!init_score.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1276
1277
        std::vector<double> oneline_init_score(num_class_);
        predict_fun_(oneline_features, oneline_init_score.data());
1278
        for (int k = 0; k < num_class_; ++k) {
1279
          init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1280
1281
1282
        }
      }
      // set label
1283
      dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1284
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1285
1286
      // push data
      for (auto& inner_data : oneline_features) {
1287
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1288
1289
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1290
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1291
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1292
1293
1294
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
1295
          if (dataset->has_raw()) {
1296
            feature_row[feature_idx] = static_cast<float>(inner_data.second);
1297
          }
Guolin Ke's avatar
Guolin Ke committed
1298
1299
        } else {
          if (inner_data.first == weight_idx_) {
1300
            dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1301
1302
1303
1304
1305
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1306
1307
1308
1309
1310
1311
1312
1313
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1314
      dataset->FinishOneRow(tid, i, is_feature_added);
1315
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1316
    }
1317
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1318
  };
1319
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
1320
  if (!used_data_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1321
1322
1323
1324
1325
1326
1327
1328
    // only need part of data
    text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
  } else {
    // need full data
    text_reader.ReadAllAndProcessParallel(process_fun);
  }

  // metadata_ will manage space of init_score
Guolin Ke's avatar
Guolin Ke committed
1329
  if (!init_score.empty()) {
1330
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1331
  }
Guolin Ke's avatar
Guolin Ke committed
1332
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1333
1334
1335
}

/*! \brief Check can load from binary file */
1336
std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
1337
1338
1339
  std::string bin_filename(filename);
  bin_filename.append(".bin");

1340
  auto reader = VirtualFileReader::Make(bin_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
1341

1342
  if (!reader->Init()) {
1343
    bin_filename = std::string(filename);
1344
1345
    reader = VirtualFileReader::Make(bin_filename.c_str());
    if (!reader->Init()) {
1346
      Log::Fatal("Cannot open data file %s", bin_filename.c_str());
1347
    }
1348
  }
1349
1350
1351
1352
1353

  size_t buffer_size = 256;
  auto buffer = std::vector<char>(buffer_size);
  // read size of token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
1354
  size_t read_cnt = reader->Read(buffer.data(), size_of_token);
1355
1356
  if (read_cnt == size_of_token
      && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
1357
    return bin_filename;
Guolin Ke's avatar
Guolin Ke committed
1358
  } else {
1359
    return std::string();
Guolin Ke's avatar
Guolin Ke committed
1360
1361
1362
  }
}

1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375


std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
                                                              const std::unordered_set<int>& categorical_features) {
  std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
  if (forced_bins_path != "") {
    std::ifstream forced_bins_stream(forced_bins_path.c_str());
    if (forced_bins_stream.fail()) {
      Log::Warning("Could not open %s. Will ignore.", forced_bins_path.c_str());
    } else {
      std::stringstream buffer;
      buffer << forced_bins_stream.rdbuf();
      std::string err;
Guolin Ke's avatar
Guolin Ke committed
1376
      Json forced_bins_json = Json::parse(buffer.str(), &err);
1377
1378
1379
1380
      CHECK(forced_bins_json.is_array());
      std::vector<Json> forced_bins_arr = forced_bins_json.array_items();
      for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
        int feature_num = forced_bins_arr[i]["feature"].int_value();
Nikita Titov's avatar
Nikita Titov committed
1381
        CHECK_LT(feature_num, num_total_features);
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
        if (categorical_features.count(feature_num)) {
          Log::Warning("Feature %d is categorical. Will ignore forced bins for this  feature.", feature_num);
        } else {
          std::vector<Json> bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items();
          for (size_t j = 0; j < bounds_arr.size(); ++j) {
            forced_bins[feature_num].push_back(bounds_arr[j].number_value());
          }
        }
      }
      // remove duplicates
      for (int i = 0; i < num_total_features; ++i) {
        auto new_end = std::unique(forced_bins[i].begin(), forced_bins[i].end());
        forced_bins[i].erase(new_end, forced_bins[i].end());
      }
    }
  }
  return forced_bins;
}

1401
}  // namespace LightGBM