dataset_loader.cpp 58.6 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/dataset_loader.h>

Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/network.h>
8
#include <LightGBM/utils/array_args.h>
9
#include <LightGBM/utils/json11.h>
10
11
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
12

13
14
#include <fstream>

Guolin Ke's avatar
Guolin Ke committed
15
16
namespace LightGBM {

17
18
using json11::Json;

Guolin Ke's avatar
Guolin Ke committed
19
20
/*!
 * \brief Build a loader bound to a configuration and (optionally) a data file.
 * \param io_config Loader configuration; bound to config_
 * \param predict_fun Prediction callback; stored in predict_fun_
 * \param num_class Number of classes; stored in num_class_
 * \param filename Data file handed to SetHeader() to resolve column roles; may be nullptr
 */
DatasetLoader::DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename)
  : config_(io_config),
    random_(config_.data_random_seed),
    predict_fun_(predict_fun),
    num_class_(num_class) {
  // Column roles before the header / config are consulted:
  // label is the first column, weight and group are unset.
  label_idx_ = 0;
  weight_idx_ = NO_SPECIFIC;
  group_idx_ = NO_SPECIFIC;

  // May overwrite the three indices above according to the configuration.
  SetHeader(filename);

  // Raw feature values are kept exactly when linear_tree is enabled.
  store_raw_ = io_config.linear_tree ? true : false;
}

/*! \brief Destructor; performs no explicit work beyond destroying the members. */
DatasetLoader::~DatasetLoader() = default;

Guolin Ke's avatar
Guolin Ke committed
34
void DatasetLoader::SetHeader(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
35
  std::unordered_map<std::string, int> name2idx;
Guolin Ke's avatar
Guolin Ke committed
36
37
  std::string name_prefix("name:");
  if (filename != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
38
    TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
39

Guolin Ke's avatar
Guolin Ke committed
40
    // get column names
Guolin Ke's avatar
Guolin Ke committed
41
    if (config_.header) {
Guolin Ke's avatar
Guolin Ke committed
42
      std::string first_line = text_reader.first_line();
43
      feature_names_ = Common::Split(first_line.c_str(), "\t,");
Guolin Ke's avatar
Guolin Ke committed
44
45
    }

Guolin Ke's avatar
Guolin Ke committed
46
    // load label idx first
Guolin Ke's avatar
Guolin Ke committed
47
48
49
    if (config_.label_column.size() > 0) {
      if (Common::StartsWith(config_.label_column, name_prefix)) {
        std::string name = config_.label_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
50
51
52
53
54
55
56
57
58
59
        label_idx_ = -1;
        for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
          if (name == feature_names_[i]) {
            label_idx_ = i;
            break;
          }
        }
        if (label_idx_ >= 0) {
          Log::Info("Using column %s as label", name.c_str());
        } else {
60
61
          Log::Fatal("Could not find label column %s in data file \n"
                     "or data file doesn't contain header", name.c_str());
Guolin Ke's avatar
Guolin Ke committed
62
        }
Guolin Ke's avatar
Guolin Ke committed
63
      } else {
Guolin Ke's avatar
Guolin Ke committed
64
        if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
65
66
67
          Log::Fatal("label_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
68
69
        }
        Log::Info("Using column number %d as label", label_idx_);
Guolin Ke's avatar
Guolin Ke committed
70
71
      }
    }
Guolin Ke's avatar
Guolin Ke committed
72

Guolin Ke's avatar
Guolin Ke committed
73
    if (!feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
74
75
76
77
      // erase label column name
      feature_names_.erase(feature_names_.begin() + label_idx_);
      for (size_t i = 0; i < feature_names_.size(); ++i) {
        name2idx[feature_names_[i]] = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
78
      }
Guolin Ke's avatar
Guolin Ke committed
79
80
81
    }

    // load ignore columns
Guolin Ke's avatar
Guolin Ke committed
82
83
84
    if (config_.ignore_column.size() > 0) {
      if (Common::StartsWith(config_.ignore_column, name_prefix)) {
        std::string names = config_.ignore_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
85
86
87
88
89
90
91
92
93
        for (auto name : Common::Split(names.c_str(), ',')) {
          if (name2idx.count(name) > 0) {
            int tmp = name2idx[name];
            ignore_features_.emplace(tmp);
          } else {
            Log::Fatal("Could not find ignore column %s in data file", name.c_str());
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
94
        for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
Guolin Ke's avatar
Guolin Ke committed
95
96
          int tmp = 0;
          if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
97
98
99
            Log::Fatal("ignore_column is not a number,\n"
                       "if you want to use a column name,\n"
                       "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
100
101
          }
          ignore_features_.emplace(tmp);
Guolin Ke's avatar
Guolin Ke committed
102
103
104
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
105
    // load weight idx
Guolin Ke's avatar
Guolin Ke committed
106
107
108
    if (config_.weight_column.size() > 0) {
      if (Common::StartsWith(config_.weight_column, name_prefix)) {
        std::string name = config_.weight_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
109
110
111
112
113
114
        if (name2idx.count(name) > 0) {
          weight_idx_ = name2idx[name];
          Log::Info("Using column %s as weight", name.c_str());
        } else {
          Log::Fatal("Could not find weight column %s in data file", name.c_str());
        }
Guolin Ke's avatar
Guolin Ke committed
115
      } else {
Guolin Ke's avatar
Guolin Ke committed
116
        if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
117
118
119
          Log::Fatal("weight_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
120
121
        }
        Log::Info("Using column number %d as weight", weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
122
      }
Guolin Ke's avatar
Guolin Ke committed
123
      ignore_features_.emplace(weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
124
    }
Guolin Ke's avatar
Guolin Ke committed
125
    // load group idx
Guolin Ke's avatar
Guolin Ke committed
126
127
128
    if (config_.group_column.size() > 0) {
      if (Common::StartsWith(config_.group_column, name_prefix)) {
        std::string name = config_.group_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
129
130
131
132
133
134
135
        if (name2idx.count(name) > 0) {
          group_idx_ = name2idx[name];
          Log::Info("Using column %s as group/query id", name.c_str());
        } else {
          Log::Fatal("Could not find group/query column %s in data file", name.c_str());
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
136
        if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
137
138
139
          Log::Fatal("group_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
140
141
142
143
        }
        Log::Info("Using column number %d as group/query id", group_idx_);
      }
      ignore_features_.emplace(group_idx_);
Guolin Ke's avatar
Guolin Ke committed
144
145
    }
  }
Guolin Ke's avatar
Guolin Ke committed
146
147
148
  if (config_.categorical_feature.size() > 0) {
    if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
      std::string names = config_.categorical_feature.substr(name_prefix.size());
149
150
151
152
153
      for (auto name : Common::Split(names.c_str(), ',')) {
        if (name2idx.count(name) > 0) {
          int tmp = name2idx[name];
          categorical_features_.emplace(tmp);
        } else {
Guolin Ke's avatar
Guolin Ke committed
154
          Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
155
156
157
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
158
      for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
159
160
        int tmp = 0;
        if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Guolin Ke's avatar
Guolin Ke committed
161
          Log::Fatal("categorical_feature is not a number,\n"
162
163
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
164
165
166
167
168
        }
        categorical_features_.emplace(tmp);
      }
    }
  }
Guolin Ke's avatar
Guolin Ke committed
169
170
}

171
172
173
174
175
176
177
178
179
180
/*!
 * \brief Warn when the sample used for bin construction is small.
 *
 * The warning is emitted only if the sample is both less than 20% of the data
 * and smaller than 100000 rows.
 *
 * \param sample_cnt Number of sampled rows
 * \param num_data Total number of rows
 */
void CheckSampleSize(size_t sample_cnt, size_t num_data) {
  const bool below_fraction = static_cast<double>(sample_cnt) / num_data < 0.2f;
  const bool below_absolute = sample_cnt < 100000;
  if (!(below_fraction && below_absolute)) {
    return;
  }
  Log::Warning(
      "Using too small ``bin_construct_sample_cnt`` may encounter unexpected "
      "errors and poor accuracy.");
}

181
/*!
 * \brief Load a training dataset from a text file, or from its binary version when
 *        CheckCanLoadFromBin() returns a non-empty file name.
 *
 * Text loading runs in one of two modes selected by config_.two_round:
 *  - one round: all lines are loaded into memory, sampled, and features are extracted from memory;
 *  - two round: the file is sampled first, then read a second time to extract features.
 * In both modes bin mappers are constructed from the sample before features are extracted.
 *
 * \param filename Path of the data file
 * \param rank Rank of this machine, forwarded to the loading / sampling helpers
 * \param num_machines Number of machines taking part in training
 * \return Newly allocated dataset; ownership is released to the caller
 */
Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_machines) {
  // don't support query id in data file when training in parallel
  if (num_machines > 1 && !config_.pre_partition) {
    // group_idx_ stays NO_SPECIFIC unless a group column was configured; a configured
    // column may legitimately have index 0, so compare against the sentinel rather than 0.
    if (group_idx_ != NO_SPECIFIC) {
      Log::Fatal("Using a query id without pre-partitioning the data file is not supported for parallel training.\n"
                 "Please use an additional query file or pre-partition the data");
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  if (store_raw_) {
    dataset->SetHasRaw(true);
  }
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto bin_filename = CheckCanLoadFromBin(filename);
  bool is_load_from_binary = false;
  if (bin_filename.size() == 0) {
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data to memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // sample data
      auto sample_data = SampleTextDataFromMemory(text_data);
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      // sample data from file
      auto sample_data = SampleTextDataFromFile(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      // a non-empty index list means only those rows are used locally
      if (used_data_indices.size() > 0) {
        dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
      } else {
        dataset->num_data_ = num_global_data;
      }
      CheckSampleSize(sample_data.size(),
                      static_cast<size_t>(dataset->num_data_));
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      if (dataset->has_raw()) {
        dataset->ResizeRaw(dataset->num_data_);
      }
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      Log::Debug("Making second pass...");
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file; this replaces the dataset object created above
    is_load_from_binary = true;
    Log::Info("Load from binary file %s", bin_filename.c_str());
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
  }
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  // need to check training data
  CheckDataset(dataset.get(), is_load_from_binary);

  return dataset.release();
}



260
/*!
 * \brief Load a dataset from file and align it with an existing training dataset.
 *
 * The binary version of the file is used when CheckCanLoadFromBin() returns a non-empty
 * file name. Otherwise the text file is parsed, and the new dataset is set up through
 * Dataset::CreateValid(train_data) before features are extracted. All loading helpers are
 * called with rank 0 and a single machine. Unlike LoadFromFile(), CheckDataset() is not
 * called here.
 *
 * \param filename Path of the data file
 * \param train_data Training dataset passed to Dataset::CreateValid()
 * \return Newly allocated dataset; ownership is released to the caller
 */
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
  data_size_t total_rows = 0;
  std::vector<data_size_t> kept_rows;
  std::unique_ptr<Dataset> valid_data(new Dataset());
  if (store_raw_) {
    valid_data->SetHasRaw(true);
  }

  auto cached_bin_path = CheckCanLoadFromBin(filename);
  if (cached_bin_path.size() != 0) {
    // A binary version exists: load it instead of parsing the text file.
    valid_data.reset(LoadFromBinFile(filename, cached_bin_path.c_str(), 0, 1, &total_rows, &kept_rows));
  } else {
    std::unique_ptr<Parser> line_parser(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (!line_parser) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    valid_data->data_filename_ = filename;
    valid_data->label_idx_ = label_idx_;
    valid_data->metadata_.Init(filename);

    if (config_.two_round) {
      // Two-round mode: only count the lines now, features are read from the file later.
      TextReader<data_size_t> line_counter(filename, config_.header);
      valid_data->num_data_ = static_cast<data_size_t>(line_counter.CountLine());
      total_rows = valid_data->num_data_;
      // Set up label storage, then align with the training data.
      valid_data->metadata_.Init(valid_data->num_data_, weight_idx_, group_idx_);
      valid_data->CreateValid(train_data);
      if (valid_data->has_raw()) {
        valid_data->ResizeRaw(valid_data->num_data_);
      }
      // Fill in the features straight from the file.
      ExtractFeaturesFromFile(filename, line_parser.get(), kept_rows, valid_data.get());
    } else {
      // One-round mode: keep every line in memory.
      auto lines = LoadTextDataToMemory(filename, valid_data->metadata_, 0, 1, &total_rows, &kept_rows);
      valid_data->num_data_ = static_cast<data_size_t>(lines.size());
      // Set up label storage, then align with the training data.
      valid_data->metadata_.Init(valid_data->num_data_, weight_idx_, group_idx_);
      valid_data->CreateValid(train_data);
      if (valid_data->has_raw()) {
        valid_data->ResizeRaw(valid_data->num_data_);
      }
      // Fill in the features from the in-memory lines, then drop the lines.
      ExtractFeaturesFromMemory(&lines, line_parser.get(), valid_data.get());
      lines.clear();
    }
  }

  // Validation data is not passed through CheckDataset(); only the meta data is checked.
  valid_data->metadata_.CheckOrPartition(total_rows, kept_rows);
  return valid_data.release();
}

313
314
315
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
                                        int rank, int num_machines, int* num_global_data,
                                        std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
316
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
317
  auto reader = VirtualFileReader::Make(bin_filename);
Guolin Ke's avatar
Guolin Ke committed
318
  dataset->data_filename_ = data_filename;
319
  if (!reader->Init()) {
Guolin Ke's avatar
Guolin Ke committed
320
321
322
323
324
    Log::Fatal("Could not read binary data from %s", bin_filename);
  }

  // buffer to read binary file
  size_t buffer_size = 16 * 1024 * 1024;
Guolin Ke's avatar
Guolin Ke committed
325
  auto buffer = std::vector<char>(buffer_size);
326

327
328
  // check token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
329
330
331
332
  size_t read_cnt = reader->Read(
      buffer.data(),
      VirtualFileWriter::AlignedSize(sizeof(char) * size_of_token));
  if (read_cnt < sizeof(char) * size_of_token) {
333
334
335
    Log::Fatal("Binary file error: token has the wrong size");
  }
  if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
336
    Log::Fatal("Input file is not LightGBM binary file");
337
  }
Guolin Ke's avatar
Guolin Ke committed
338
339

  // read size of header
340
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
341

342
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
343
344
345
    Log::Fatal("Binary file error: header has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
346
  size_t size_of_head = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
347
348
349
350

  // re-allocmate space if not enough
  if (size_of_head > buffer_size) {
    buffer_size = size_of_head;
Guolin Ke's avatar
Guolin Ke committed
351
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
352
353
  }
  // read header
354
  read_cnt = reader->Read(buffer.data(), size_of_head);
Guolin Ke's avatar
Guolin Ke committed
355
356
357
358
359

  if (read_cnt != size_of_head) {
    Log::Fatal("Binary file error: header is incorrect");
  }
  // get header
Guolin Ke's avatar
Guolin Ke committed
360
  const char* mem_ptr = buffer.data();
Guolin Ke's avatar
Guolin Ke committed
361
  dataset->num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
362
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_data_));
Guolin Ke's avatar
Guolin Ke committed
363
  dataset->num_features_ = *(reinterpret_cast<const int*>(mem_ptr));
364
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_features_));
Guolin Ke's avatar
Guolin Ke committed
365
  dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
366
367
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(dataset->num_total_features_));
Guolin Ke's avatar
Guolin Ke committed
368
  dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
369
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->label_idx_));
370
  dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
371
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->max_bin_));
372
  dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
373
374
  mem_ptr += VirtualFileWriter::AlignedSize(
      sizeof(dataset->bin_construct_sample_cnt_));
375
  dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
376
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->min_data_in_bin_));
377
  dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
378
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->use_missing_));
379
  dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
380
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->zero_as_missing_));
381
382
  dataset->has_raw_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->has_raw_));
Guolin Ke's avatar
Guolin Ke committed
383
384
  const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
  dataset->used_feature_map_.clear();
Guolin Ke's avatar
Guolin Ke committed
385
  for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
386
387
    dataset->used_feature_map_.push_back(tmp_feature_map[i]);
  }
388
389
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int) *
                                            dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
390
391
  // num_groups
  dataset->num_groups_ = *(reinterpret_cast<const int*>(mem_ptr));
392
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
393
394
395
396
397
398
  // real_feature_idx_
  const int* tmp_ptr_real_feature_idx_ = reinterpret_cast<const int*>(mem_ptr);
  dataset->real_feature_idx_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->real_feature_idx_.push_back(tmp_ptr_real_feature_idx_[i]);
  }
399
400
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
401
402
403
404
405
406
  // feature2group
  const int* tmp_ptr_feature2group = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2group_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2group_.push_back(tmp_ptr_feature2group[i]);
  }
407
408
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
409
410
411
412
413
414
  // feature2subfeature
  const int* tmp_ptr_feature2subfeature = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2subfeature_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2subfeature_.push_back(tmp_ptr_feature2subfeature[i]);
  }
415
416
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
417
418
419
420
421
422
423
424
425
426
427
  // group_bin_boundaries
  const uint64_t* tmp_ptr_group_bin_boundaries = reinterpret_cast<const uint64_t*>(mem_ptr);
  dataset->group_bin_boundaries_.clear();
  for (int i = 0; i < dataset->num_groups_ + 1; ++i) {
    dataset->group_bin_boundaries_.push_back(tmp_ptr_group_bin_boundaries[i]);
  }
  mem_ptr += sizeof(uint64_t) * (dataset->num_groups_ + 1);

  // group_feature_start_
  const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_start_.clear();
428
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
429
430
    dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
  }
431
432
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
433
434
435
436
437
438
439

  // group_feature_cnt_
  const int* tmp_ptr_group_feature_cnt = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_cnt_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_cnt_.push_back(tmp_ptr_group_feature_cnt[i]);
  }
440
441
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
442

Belinda Trotta's avatar
Belinda Trotta committed
443
  if (!config_.max_bin_by_feature.empty()) {
444
445
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
446
447
448
449
450
451
452
453
454
    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
  } else {
    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
    dataset->max_bin_by_feature_.clear();
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
    }
  }
455
456
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int32_t) *
                                            (dataset->num_total_features_));
Belinda Trotta's avatar
Belinda Trotta committed
457
458
459
460
  if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
    dataset->max_bin_by_feature_.clear();
  }

Guolin Ke's avatar
Guolin Ke committed
461
  // get feature names
Guolin Ke's avatar
Guolin Ke committed
462
  dataset->feature_names_.clear();
Guolin Ke's avatar
Guolin Ke committed
463
464
465
  // write feature names
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int str_len = *(reinterpret_cast<const int*>(mem_ptr));
466
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
Guolin Ke's avatar
Guolin Ke committed
467
    std::stringstream str_buf;
468
    auto tmp_arr = reinterpret_cast<const char*>(mem_ptr);
Guolin Ke's avatar
Guolin Ke committed
469
    for (int j = 0; j < str_len; ++j) {
470
      char tmp_char = tmp_arr[j];
Guolin Ke's avatar
Guolin Ke committed
471
472
      str_buf << tmp_char;
    }
473
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(char) * str_len);
Guolin Ke's avatar
Guolin Ke committed
474
    dataset->feature_names_.emplace_back(str_buf.str());
Guolin Ke's avatar
Guolin Ke committed
475
  }
476
477
478
479
  // get forced_bin_bounds_
  dataset->forced_bin_bounds_ = std::vector<std::vector<double>>(dataset->num_total_features_, std::vector<double>());
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int num_bounds = *(reinterpret_cast<const int*>(mem_ptr));
480
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
481
    dataset->forced_bin_bounds_[i] = std::vector<double>();
482
483
    const double* tmp_ptr_forced_bounds =
        reinterpret_cast<const double*>(mem_ptr);
484
485
486
487
488
489
    for (int j = 0; j < num_bounds; ++j) {
      double bound = tmp_ptr_forced_bounds[j];
      dataset->forced_bin_bounds_[i].push_back(bound);
    }
    mem_ptr += num_bounds * sizeof(double);
  }
Guolin Ke's avatar
Guolin Ke committed
490
491

  // read size of meta data
492
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
493

494
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
495
496
497
    Log::Fatal("Binary file error: meta data has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
498
  size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
499
500
501
502

  // re-allocate space if not enough
  if (size_of_metadata > buffer_size) {
    buffer_size = size_of_metadata;
Guolin Ke's avatar
Guolin Ke committed
503
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
504
505
  }
  //  read meta data
506
  read_cnt = reader->Read(buffer.data(), size_of_metadata);
Guolin Ke's avatar
Guolin Ke committed
507
508
509
510
511

  if (read_cnt != size_of_metadata) {
    Log::Fatal("Binary file error: meta data is incorrect");
  }
  // load meta data
Guolin Ke's avatar
Guolin Ke committed
512
  dataset->metadata_.LoadFromMemory(buffer.data());
Guolin Ke's avatar
Guolin Ke committed
513

514
515
  *num_global_data = dataset->num_data_;
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
516
  // sample local used data if need to partition
Guolin Ke's avatar
Guolin Ke committed
517
  if (num_machines > 1 && !config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
518
519
520
521
    const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
522
        if (random_.NextShort(0, num_machines) == rank) {
523
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
524
525
526
527
528
529
530
531
532
        }
      }
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = dataset->metadata_.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (qid >= num_queries) {
533
534
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
535
536
537
538
        }
        if (i >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
539
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
540
541
542
543
544
            is_query_used = true;
          }
          ++qid;
        }
        if (is_query_used) {
545
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
546
547
548
        }
      }
    }
549
    dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
Guolin Ke's avatar
Guolin Ke committed
550
  }
551
  dataset->metadata_.PartitionLabel(*used_data_indices);
Guolin Ke's avatar
Guolin Ke committed
552
  // read feature data
Guolin Ke's avatar
Guolin Ke committed
553
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
554
    // read feature size
555
556
    read_cnt = reader->Read(buffer.data(), sizeof(size_t));
    if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
557
558
      Log::Fatal("Binary file error: feature %d has the wrong size", i);
    }
Guolin Ke's avatar
Guolin Ke committed
559
    size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
560
561
562
    // re-allocate space if not enough
    if (size_of_feature > buffer_size) {
      buffer_size = size_of_feature;
Guolin Ke's avatar
Guolin Ke committed
563
      buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
564
565
    }

566
    read_cnt = reader->Read(buffer.data(), size_of_feature);
Guolin Ke's avatar
Guolin Ke committed
567
568
569
570

    if (read_cnt != size_of_feature) {
      Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
    }
Guolin Ke's avatar
Guolin Ke committed
571
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
572
573
      new FeatureGroup(buffer.data(),
                       *num_global_data,
574
                       *used_data_indices, i)));
Guolin Ke's avatar
Guolin Ke committed
575
  }
Guolin Ke's avatar
Guolin Ke committed
576
  dataset->feature_groups_.shrink_to_fit();
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613

  // raw data
  dataset->numeric_feature_map_ = std::vector<int>(dataset->num_features_, false);
  dataset->num_numeric_features_ = 0;
  for (int i = 0; i < dataset->num_features_; ++i) {
    if (dataset->FeatureBinMapper(i)->bin_type() == BinType::CategoricalBin) {
      dataset->numeric_feature_map_[i] = -1;
    } else {
      dataset->numeric_feature_map_[i] = dataset->num_numeric_features_;
      ++dataset->num_numeric_features_;
    }
  }
  if (dataset->has_raw()) {
    dataset->ResizeRaw(dataset->num_data());
      size_t row_size = dataset->num_numeric_features_ * sizeof(float);
      if (row_size > buffer_size) {
        buffer_size = row_size;
        buffer.resize(buffer_size);
      }
    for (int i = 0; i < dataset->num_data(); ++i) {
      read_cnt = reader->Read(buffer.data(), row_size);
      if (read_cnt != row_size) {
        Log::Fatal("Binary file error: row %d of raw data is incorrect, read count: %d", i, read_cnt);
      }
      mem_ptr = buffer.data();
      const float* tmp_ptr_raw_row = reinterpret_cast<const float*>(mem_ptr);
      std::vector<float> curr_row(dataset->num_numeric_features_, 0);
      for (int j = 0; j < dataset->num_features(); ++j) {
        int feat_ind = dataset->numeric_feature_map_[j];
        if (feat_ind >= 0) {
          dataset->raw_data_[feat_ind][i] = tmp_ptr_raw_row[feat_ind];
        }
      }
      mem_ptr += row_size;
    }
  }

Guolin Ke's avatar
Guolin Ke committed
614
  dataset->is_finish_load_ = true;
Guolin Ke's avatar
Guolin Ke committed
615
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
616
617
}

618

619
620
621
/*!
 * \brief Construct a Dataset from column-wise sampled values.
 *
 * Bin mappers are found for every non-ignored feature, either locally (single machine)
 * or split across machines and exchanged with Allgather (multi-machine), and are then
 * handed to Dataset::Construct.
 *
 * \param sample_values Per-column arrays of sampled values; column i holds num_per_col[i] entries
 * \param sample_indices Per-column arrays of sample indices, forwarded to Dataset::Construct
 * \param num_col Number of columns available in sample_values on this machine
 * \param num_per_col Number of sampled values stored for each column
 * \param total_sample_size Total number of sampled rows, passed to BinMapper::FindBin
 * \param num_data Number of rows the resulting Dataset is created for
 * \return Newly allocated Dataset; ownership is released to the caller
 */
Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
                                                int** sample_indices, int num_col, const int* num_per_col,
                                                size_t total_sample_size, data_size_t num_data) {
  CheckSampleSize(total_sample_size, static_cast<size_t>(num_data));
  int num_total_features = num_col;
  if (Network::num_machines() > 1) {
    // machines may see a different number of columns, agree on the largest one
    num_total_features = Network::GlobalSyncUpByMax(num_total_features);
  }
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_total_features);
  // fill feature_names_ if not header
  if (feature_names_.empty()) {
    for (int i = 0; i < num_col; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
  if (!config_.max_bin_by_feature.empty()) {
    CHECK_EQ(static_cast<size_t>(num_col), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
  }

  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);

  // scale min_data_in_leaf down to the size of the sample
  const data_size_t filter_cnt = static_cast<data_size_t>(
    static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
  const bool use_global_max_bin = config_.max_bin_by_feature.empty();
  if (Network::num_machines() == 1) {
    // if only one machine, find bin locally
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
        bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
        if (!feat_is_unconstrained) {
            Log::Fatal("The output cannot be monotone with respect to categorical features");
        }
      }
      bin_mappers[i].reset(new BinMapper());
      const int max_bin = use_global_max_bin ? config_.max_bin : config_.max_bin_by_feature[i];
      bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                              max_bin, config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter,
                              bin_type, config_.use_missing, config_.zero_as_missing,
                              forced_bin_bounds[i]);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    // if have multi-machines, need to find bin distributed
    // different machines will find bin for different features
    int num_machines = Network::num_machines();
    int rank = Network::rank();
    // start and len will store the process feature indices for different machines
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
    int step = (num_total_features + num_machines - 1) / num_machines;
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
      len[i] = std::min(step, num_total_features - start[i]);
      start[i + 1] = start[i] + len[i];
    }
    len[num_machines - 1] = num_total_features - start[num_machines - 1];
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < len[rank]; ++i) {
      OMP_LOOP_EX_BEGIN();
      // i is only the temporary local slot; every per-feature lookup uses the global index
      const int feature_idx = start[rank] + i;
      if (ignore_features_.count(feature_idx) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(feature_idx)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
      if (num_col <= feature_idx) {
        // this machine has no samples for the feature, keep the default-constructed mapper
        continue;
      }
      const int max_bin = use_global_max_bin ? config_.max_bin : config_.max_bin_by_feature[feature_idx];
      // forced bin bounds are stored per global feature index, not per local slot
      bin_mappers[i]->FindBin(sample_values[feature_idx], num_per_col[feature_idx],
                              total_sample_size, max_bin, config_.min_data_in_bin,
                              filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
                              config_.zero_as_missing, forced_bin_bounds[feature_idx]);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
    // total number of bytes needed to serialize the local bin mappers
    comm_size_t self_buf_size = 0;
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
    }
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
      // free
      bin_mappers[i].reset(nullptr);
    }
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
    }
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
    // gather global feature bin mappers
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
    // restore features bins from buffer
    for (int i = 0; i < num_total_features; ++i) {
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      bin_mappers[i].reset(new BinMapper());
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
  dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
  if (dataset->has_raw()) {
    dataset->ResizeRaw(num_data);
  }
  dataset->set_feature_names(feature_names_);
  return dataset.release();
}
Guolin Ke's avatar
Guolin Ke committed
774
775
776
777


// ---- private functions ----

778
void DatasetLoader::CheckDataset(const Dataset* dataset, bool is_load_from_binary) {
Guolin Ke's avatar
Guolin Ke committed
779
  if (dataset->num_data_ <= 0) {
Guolin Ke's avatar
Guolin Ke committed
780
    Log::Fatal("Data file %s is empty", dataset->data_filename_.c_str());
Guolin Ke's avatar
Guolin Ke committed
781
  }
782
783
  if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
    Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
784
               static_cast<int>(dataset->feature_names_.size()));
785
  }
Guolin Ke's avatar
Guolin Ke committed
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
  bool is_feature_order_by_group = true;
  int last_group = -1;
  int last_sub_feature = -1;
  // if features are ordered, not need to use hist_buf
  for (int i = 0; i < dataset->num_features_; ++i) {
    int group = dataset->feature2group_[i];
    int sub_feature = dataset->feature2subfeature_[i];
    if (group < last_group) {
      is_feature_order_by_group = false;
    } else if (group == last_group) {
      if (sub_feature <= last_sub_feature) {
        is_feature_order_by_group = false;
        break;
      }
    }
    last_group = group;
    last_sub_feature = sub_feature;
  }
  if (!is_feature_order_by_group) {
805
    Log::Fatal("Features in dataset should be ordered by group");
Guolin Ke's avatar
Guolin Ke committed
806
  }
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832

  if (is_load_from_binary) {
    if (dataset->max_bin_ != config_.max_bin) {
      Log::Fatal("Dataset max_bin %d != config %d", dataset->max_bin_, config_.max_bin);
    }
    if (dataset->min_data_in_bin_ != config_.min_data_in_bin) {
      Log::Fatal("Dataset min_data_in_bin %d != config %d", dataset->min_data_in_bin_, config_.min_data_in_bin);
    }
    if (dataset->use_missing_ != config_.use_missing) {
      Log::Fatal("Dataset use_missing %d != config %d", dataset->use_missing_, config_.use_missing);
    }
    if (dataset->zero_as_missing_ != config_.zero_as_missing) {
      Log::Fatal("Dataset zero_as_missing %d != config %d", dataset->zero_as_missing_, config_.zero_as_missing);
    }
    if (dataset->bin_construct_sample_cnt_ != config_.bin_construct_sample_cnt) {
      Log::Fatal("Dataset bin_construct_sample_cnt %d != config %d", dataset->bin_construct_sample_cnt_, config_.bin_construct_sample_cnt);
    }
    if ((dataset->max_bin_by_feature_.size() != config_.max_bin_by_feature.size()) ||
        !std::equal(dataset->max_bin_by_feature_.begin(), dataset->max_bin_by_feature_.end(),
            config_.max_bin_by_feature.begin())) {
      Log::Fatal("Dataset max_bin_by_feature does not match with config");
    }

    int label_idx = -1;
    if (Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx)) {
      if (dataset->label_idx_ != label_idx) {
833
        Log::Fatal("Dataset label_idx %d != config %d", dataset->label_idx_, label_idx);
834
835
836
837
838
      }
    } else {
      Log::Info("Recommend use integer for label index when loading data from binary for sanity check.");
    }
  }
Guolin Ke's avatar
Guolin Ke committed
839
840
841
}

std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
842
843
                                                             int rank, int num_machines, int* num_global_data,
                                                             std::vector<data_size_t>* used_data_indices) {
844
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
845
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
846
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
847
848
849
850
851
852
853
854
855
    // read all lines
    *num_global_data = text_reader.ReadAllLines();
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();

    if (query_boundaries == nullptr) {
      // if not contain query data, minimal sample unit is one record
      *num_global_data = text_reader.ReadAndFilterLines([this, rank, num_machines](data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
856
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
857
858
859
860
861
862
863
864
865
866
867
868
869
870
          return true;
        } else {
          return false;
        }
      }, used_data_indices);
    } else {
      // if contain query data, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.ReadAndFilterLines(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
871
872
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
873
874
875
876
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
877
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
878
879
880
881
882
883
884
885
886
887
888
889
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices);
    }
  }
  return std::move(text_reader.Lines());
}

std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
Guolin Ke's avatar
Guolin Ke committed
890
  int sample_cnt = config_.bin_construct_sample_cnt;
891
892
  if (static_cast<size_t>(sample_cnt) > data.size()) {
    sample_cnt = static_cast<int>(data.size());
893
  }
894
  auto sample_indices = random_.Sample(static_cast<int>(data.size()), sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
895
  std::vector<std::string> out(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
896
897
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    const size_t idx = sample_indices[i];
Guolin Ke's avatar
Guolin Ke committed
898
    out[i] = data[idx];
Guolin Ke's avatar
Guolin Ke committed
899
900
901
902
  }
  return out;
}

903
904
905
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata,
                                                               int rank, int num_machines, int* num_global_data,
                                                               std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
906
  const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
907
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
908
  std::vector<std::string> out_data;
Guolin Ke's avatar
Guolin Ke committed
909
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
910
    *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(&random_, sample_cnt, &out_data));
Guolin Ke's avatar
Guolin Ke committed
911
912
913
914
915
916
917
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      *num_global_data = text_reader.SampleAndFilterFromFile([this, rank, num_machines]
      (data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
918
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
919
920
921
922
          return true;
        } else {
          return false;
        }
Guolin Ke's avatar
Guolin Ke committed
923
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
924
925
926
927
928
929
930
931
932
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.SampleAndFilterFromFile(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
933
934
          Log::Fatal("Query id exceeds the range of the query file, "
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
935
936
937
938
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
939
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
940
941
942
943
944
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
Guolin Ke's avatar
Guolin Ke committed
945
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
946
947
948
949
950
    }
  }
  return out_data;
}

951
952
953
void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
                                                    const std::vector<std::string>& sample_data,
                                                    const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
954
  std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
955
  std::vector<std::vector<int>> sample_indices;
Guolin Ke's avatar
Guolin Ke committed
956
957
  std::vector<std::pair<int, double>> oneline_features;
  double label;
Guolin Ke's avatar
Guolin Ke committed
958
  for (int i = 0; i < static_cast<int>(sample_data.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
959
960
961
962
    oneline_features.clear();
    // parse features
    parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
    for (std::pair<int, double>& inner_data : oneline_features) {
963
      if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
964
965
        sample_values.resize(inner_data.first + 1);
        sample_indices.resize(inner_data.first + 1);
966
      }
Guolin Ke's avatar
Guolin Ke committed
967
      if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
968
969
        sample_values[inner_data.first].emplace_back(inner_data.second);
        sample_indices[inner_data.first].emplace_back(i);
Guolin Ke's avatar
Guolin Ke committed
970
971
972
973
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
974
  dataset->feature_groups_.clear();
975
976
977
978
979
  dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
  if (num_machines > 1) {
    dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_);
  }
  if (!feature_names_.empty()) {
980
    CHECK_EQ(dataset->num_total_features_, static_cast<int>(feature_names_.size()));
981
  }
Guolin Ke's avatar
Guolin Ke committed
982

Belinda Trotta's avatar
Belinda Trotta committed
983
  if (!config_.max_bin_by_feature.empty()) {
984
985
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
986
987
  }

988
989
  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
990
991
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path,
                                                                                    dataset->num_total_features_,
992
993
                                                                                    categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
994
995
996
997
998
999
  // check the range of label_idx, weight_idx and group_idx
  CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
  CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
  CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);

  // fill feature_names_ if not header
Guolin Ke's avatar
Guolin Ke committed
1000
  if (feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1001
1002
1003
1004
1005
1006
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
1007
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
1008
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
1009
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
1010
    static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
Guolin Ke's avatar
Guolin Ke committed
1011
1012
1013
  // start find bins
  if (num_machines == 1) {
    // if only one machine, find bin locally
1014
    OMP_INIT_EX();
1015
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
1016
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
1017
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1018
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1019
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1020
1021
        continue;
      }
1022
1023
1024
1025
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
Guolin Ke's avatar
Guolin Ke committed
1026
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
1027
1028
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
1029
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
1030
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
1031
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1032
1033
      } else {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
1034
                                sample_data.size(), config_.max_bin_by_feature[i],
1035
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
1036
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1037
      }
1038
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1039
    }
1040
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1041
1042
  } else {
    // start and len will store the process feature indices for different machines
1043
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
Guolin Ke's avatar
Guolin Ke committed
1044
1045
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
1046
    int step = (dataset->num_total_features_ + num_machines - 1) / num_machines;
Guolin Ke's avatar
Guolin Ke committed
1047
1048
1049
1050
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
1051
      len[i] = std::min(step, dataset->num_total_features_ - start[i]);
Guolin Ke's avatar
Guolin Ke committed
1052
1053
      start[i + 1] = start[i] + len[i];
    }
1054
    len[num_machines - 1] = dataset->num_total_features_ - start[num_machines - 1];
1055
    OMP_INIT_EX();
1056
    #pragma omp parallel for schedule(guided)
1057
    for (int i = 0; i < len[rank]; ++i) {
1058
      OMP_LOOP_EX_BEGIN();
1059
1060
1061
1062
1063
1064
1065
1066
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
Nikita Titov's avatar
Nikita Titov committed
1067
      if (static_cast<int>(sample_values.size()) <= start[rank] + i) {
1068
1069
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
1070
      if (config_.max_bin_by_feature.empty()) {
1071
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
1072
                                static_cast<int>(sample_values[start[rank] + i].size()),
1073
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
1074
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
1075
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1076
      } else {
1077
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
1078
                                static_cast<int>(sample_values[start[rank] + i].size()),
1079
                                sample_data.size(), config_.max_bin_by_feature[i],
1080
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type,
1081
                                config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1082
      }
1083
      OMP_LOOP_EX_END();
1084
    }
1085
    OMP_THROW_EX();
1086
    comm_size_t self_buf_size = 0;
Guolin Ke's avatar
Guolin Ke committed
1087
    for (int i = 0; i < len[rank]; ++i) {
1088
1089
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
Guolin Ke's avatar
Guolin Ke committed
1090
      }
1091
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
1092
    }
1093
1094
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1095
    for (int i = 0; i < len[rank]; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1096
1097
1098
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
1099
1100
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
1101
1102
      // free
      bin_mappers[i].reset(nullptr);
Guolin Ke's avatar
Guolin Ke committed
1103
    }
1104
1105
1106
1107
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
Guolin Ke's avatar
Guolin Ke committed
1108
    }
1109
1110
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
Guolin Ke's avatar
Guolin Ke committed
1111
    // gather global feature bin mappers
1112
1113
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1114
    // restore features bins from buffer
1115
    for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1116
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1117
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1118
1119
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
1120
      bin_mappers[i].reset(new BinMapper());
1121
1122
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
Guolin Ke's avatar
Guolin Ke committed
1123
1124
    }
  }
1125
  dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Guolin Ke's avatar
Guolin Ke committed
1126
                     Common::Vector2Ptr<double>(&sample_values).data(),
1127
                     Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
1128
1129
1130
  if (dataset->has_raw()) {
    dataset->ResizeRaw(sample_data.size());
  }
Guolin Ke's avatar
Guolin Ke committed
1131
1132
1133
}

/*! \brief Extract local features from memory */
Guolin Ke's avatar
Guolin Ke committed
1134
void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
1135
1136
  std::vector<std::pair<int, double>> oneline_features;
  double tmp_label = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
1137
  auto& ref_text_data = *text_data;
1138
  std::vector<float> feature_row(dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
1139
  if (predict_fun_ == nullptr) {
1140
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1141
    // if doesn't need to prediction with initial model
1142
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
Guolin Ke's avatar
Guolin Ke committed
1143
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1144
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1145
1146
1147
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1148
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1149
      // set label
1150
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1151
      // free processed line:
Guolin Ke's avatar
Guolin Ke committed
1152
      ref_text_data[i].clear();
Guolin Ke's avatar
Guolin Ke committed
1153
1154
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
1155
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1156
1157
      // push data
      for (auto& inner_data : oneline_features) {
1158
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1159
1160
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1161
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1162
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1163
1164
1165
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
1166
1167
1168
          if (dataset->has_raw()) {
            feature_row[feature_idx] = inner_data.second;
          }
Guolin Ke's avatar
Guolin Ke committed
1169
1170
        } else {
          if (inner_data.first == weight_idx_) {
1171
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1172
1173
1174
1175
1176
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1177
1178
1179
1180
1181
1182
1183
1184
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1185
      dataset->FinishOneRow(tid, i, is_feature_added);
1186
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1187
    }
1188
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1189
  } else {
1190
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1191
    // if need to prediction with initial model
1192
    std::vector<double> init_score(dataset->num_data_ * num_class_);
1193
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
Guolin Ke's avatar
Guolin Ke committed
1194
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1195
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1196
1197
1198
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1199
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1200
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1201
1202
      std::vector<double> oneline_init_score(num_class_);
      predict_fun_(oneline_features, oneline_init_score.data());
1203
      for (int k = 0; k < num_class_; ++k) {
1204
        init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1205
1206
      }
      // set label
1207
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1208
1209
1210
1211
1212
      // free processed line:
      text_data[i].clear();
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
Guolin Ke's avatar
Guolin Ke committed
1213
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1214
      for (auto& inner_data : oneline_features) {
1215
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1216
1217
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1218
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1219
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1220
1221
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
1222
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
1223
1224
1225
          if (dataset->has_raw()) {
            feature_row[feature_idx] = inner_data.second;
          }
Guolin Ke's avatar
Guolin Ke committed
1226
1227
        } else {
          if (inner_data.first == weight_idx_) {
1228
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1229
1230
1231
1232
1233
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1234
      dataset->FinishOneRow(tid, i, is_feature_added);
1235
1236
1237
1238
1239
1240
1241
1242
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
1243
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1244
    }
1245
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1246
    // metadata_ will manage space of init_score
1247
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1248
  }
Guolin Ke's avatar
Guolin Ke committed
1249
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1250
  // text data can be free after loaded feature values
Guolin Ke's avatar
Guolin Ke committed
1251
  text_data->clear();
Guolin Ke's avatar
Guolin Ke committed
1252
1253
1254
}

/*! \brief Extract local features from file */
1255
1256
void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser,
                                            const std::vector<data_size_t>& used_data_indices, Dataset* dataset) {
1257
  std::vector<double> init_score;
Guolin Ke's avatar
Guolin Ke committed
1258
  if (predict_fun_ != nullptr) {
1259
    init_score = std::vector<double>(dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1260
1261
1262
1263
1264
1265
  }
  std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
    [this, &init_score, &parser, &dataset]
  (data_size_t start_idx, const std::vector<std::string>& lines) {
    std::vector<std::pair<int, double>> oneline_features;
    double tmp_label = 0.0f;
1266
    std::vector<float> feature_row(dataset->num_features_);
1267
    OMP_INIT_EX();
1268
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label, feature_row)
Guolin Ke's avatar
Guolin Ke committed
1269
    for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
1270
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1271
1272
1273
1274
1275
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1276
      if (!init_score.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1277
1278
        std::vector<double> oneline_init_score(num_class_);
        predict_fun_(oneline_features, oneline_init_score.data());
1279
        for (int k = 0; k < num_class_; ++k) {
1280
          init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1281
1282
1283
        }
      }
      // set label
1284
      dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1285
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1286
1287
      // push data
      for (auto& inner_data : oneline_features) {
1288
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1289
1290
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1291
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1292
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1293
1294
1295
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
1296
1297
1298
          if (dataset->has_raw()) {
            feature_row[feature_idx] = inner_data.second;
          }
Guolin Ke's avatar
Guolin Ke committed
1299
1300
        } else {
          if (inner_data.first == weight_idx_) {
1301
            dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1302
1303
1304
1305
1306
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1307
1308
1309
1310
1311
1312
1313
1314
      if (dataset->has_raw()) {
        for (size_t j = 0; j < feature_row.size(); ++j) {
          int feat_ind = dataset->numeric_feature_map_[j];
          if (feat_ind >= 0) {
            dataset->raw_data_[feat_ind][i] = feature_row[j];
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1315
      dataset->FinishOneRow(tid, i, is_feature_added);
1316
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1317
    }
1318
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1319
  };
1320
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
1321
  if (!used_data_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1322
1323
1324
1325
1326
1327
1328
1329
    // only need part of data
    text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
  } else {
    // need full data
    text_reader.ReadAllAndProcessParallel(process_fun);
  }

  // metadata_ will manage space of init_score
Guolin Ke's avatar
Guolin Ke committed
1330
  if (!init_score.empty()) {
1331
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1332
  }
Guolin Ke's avatar
Guolin Ke committed
1333
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1334
1335
1336
}

/*! \brief Check whether the data can be loaded from a binary file */
1337
std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
1338
1339
1340
  std::string bin_filename(filename);
  bin_filename.append(".bin");

1341
  auto reader = VirtualFileReader::Make(bin_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
1342

1343
  if (!reader->Init()) {
1344
    bin_filename = std::string(filename);
1345
1346
    reader = VirtualFileReader::Make(bin_filename.c_str());
    if (!reader->Init()) {
1347
      Log::Fatal("Cannot open data file %s", bin_filename.c_str());
1348
    }
1349
  }
1350
1351
1352
1353
1354

  size_t buffer_size = 256;
  auto buffer = std::vector<char>(buffer_size);
  // read size of token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
1355
  size_t read_cnt = reader->Read(buffer.data(), size_of_token);
1356
1357
  if (read_cnt == size_of_token
      && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
1358
    return bin_filename;
Guolin Ke's avatar
Guolin Ke committed
1359
  } else {
1360
    return std::string();
Guolin Ke's avatar
Guolin Ke committed
1361
1362
1363
  }
}

1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376


std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
                                                              const std::unordered_set<int>& categorical_features) {
  std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
  if (forced_bins_path != "") {
    std::ifstream forced_bins_stream(forced_bins_path.c_str());
    if (forced_bins_stream.fail()) {
      Log::Warning("Could not open %s. Will ignore.", forced_bins_path.c_str());
    } else {
      std::stringstream buffer;
      buffer << forced_bins_stream.rdbuf();
      std::string err;
Guolin Ke's avatar
Guolin Ke committed
1377
      Json forced_bins_json = Json::parse(buffer.str(), &err);
1378
1379
1380
1381
      CHECK(forced_bins_json.is_array());
      std::vector<Json> forced_bins_arr = forced_bins_json.array_items();
      for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
        int feature_num = forced_bins_arr[i]["feature"].int_value();
Nikita Titov's avatar
Nikita Titov committed
1382
        CHECK_LT(feature_num, num_total_features);
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
        if (categorical_features.count(feature_num)) {
          Log::Warning("Feature %d is categorical. Will ignore forced bins for this  feature.", feature_num);
        } else {
          std::vector<Json> bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items();
          for (size_t j = 0; j < bounds_arr.size(); ++j) {
            forced_bins[feature_num].push_back(bounds_arr[j].number_value());
          }
        }
      }
      // remove duplicates
      for (int i = 0; i < num_total_features; ++i) {
        auto new_end = std::unique(forced_bins[i].begin(), forced_bins[i].end());
        forced_bins[i].erase(new_end, forced_bins[i].end());
      }
    }
  }
  return forced_bins;
}

1402
}  // namespace LightGBM