dataset_loader.cpp 52 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/dataset_loader.h>

Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/network.h>
8
9
10
#include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
11

12
13
#include <fstream>

14
15
#include <LightGBM/json11.hpp>

16
17
using namespace json11;

Guolin Ke's avatar
Guolin Ke committed
18
19
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
20
21
/*!
 * \brief Create a loader for the given configuration and resolve column settings.
 * \param io_config Configuration used to initialize config_
 * \param predict_fun Prediction callback stored in predict_fun_
 * \param num_class Number of classes stored in num_class_
 * \param filename Data file whose header is inspected by SetHeader; may be nullptr
 */
DatasetLoader::DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename)
  : config_(io_config),
    random_(config_.data_random_seed),
    predict_fun_(predict_fun),
    num_class_(num_class) {
  // Defaults that stay in effect unless SetHeader finds an explicit setting:
  // label in column 0, no weight column, no group column.
  group_idx_ = NO_SPECIFIC;
  weight_idx_ = NO_SPECIFIC;
  label_idx_ = 0;
  SetHeader(filename);
}

/*! \brief Destructor; nothing beyond member destruction is performed. */
DatasetLoader::~DatasetLoader() = default;

Guolin Ke's avatar
Guolin Ke committed
31
void DatasetLoader::SetHeader(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
32
  std::unordered_map<std::string, int> name2idx;
Guolin Ke's avatar
Guolin Ke committed
33
34
  std::string name_prefix("name:");
  if (filename != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
35
    TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
36

Guolin Ke's avatar
Guolin Ke committed
37
    // get column names
Guolin Ke's avatar
Guolin Ke committed
38
    if (config_.header) {
Guolin Ke's avatar
Guolin Ke committed
39
      std::string first_line = text_reader.first_line();
40
      feature_names_ = Common::Split(first_line.c_str(), "\t,");
Guolin Ke's avatar
Guolin Ke committed
41
42
    }

Guolin Ke's avatar
Guolin Ke committed
43
    // load label idx first
Guolin Ke's avatar
Guolin Ke committed
44
45
46
    if (config_.label_column.size() > 0) {
      if (Common::StartsWith(config_.label_column, name_prefix)) {
        std::string name = config_.label_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
47
48
49
50
51
52
53
54
55
56
        label_idx_ = -1;
        for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
          if (name == feature_names_[i]) {
            label_idx_ = i;
            break;
          }
        }
        if (label_idx_ >= 0) {
          Log::Info("Using column %s as label", name.c_str());
        } else {
57
58
          Log::Fatal("Could not find label column %s in data file \n"
                     "or data file doesn't contain header", name.c_str());
Guolin Ke's avatar
Guolin Ke committed
59
        }
Guolin Ke's avatar
Guolin Ke committed
60
      } else {
Guolin Ke's avatar
Guolin Ke committed
61
        if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
62
63
64
          Log::Fatal("label_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
65
66
        }
        Log::Info("Using column number %d as label", label_idx_);
Guolin Ke's avatar
Guolin Ke committed
67
68
      }
    }
Guolin Ke's avatar
Guolin Ke committed
69

Guolin Ke's avatar
Guolin Ke committed
70
    if (!feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
71
72
73
74
      // erase label column name
      feature_names_.erase(feature_names_.begin() + label_idx_);
      for (size_t i = 0; i < feature_names_.size(); ++i) {
        name2idx[feature_names_[i]] = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
75
      }
Guolin Ke's avatar
Guolin Ke committed
76
77
78
    }

    // load ignore columns
Guolin Ke's avatar
Guolin Ke committed
79
80
81
    if (config_.ignore_column.size() > 0) {
      if (Common::StartsWith(config_.ignore_column, name_prefix)) {
        std::string names = config_.ignore_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
82
83
84
85
86
87
88
89
90
        for (auto name : Common::Split(names.c_str(), ',')) {
          if (name2idx.count(name) > 0) {
            int tmp = name2idx[name];
            ignore_features_.emplace(tmp);
          } else {
            Log::Fatal("Could not find ignore column %s in data file", name.c_str());
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
91
        for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
Guolin Ke's avatar
Guolin Ke committed
92
93
          int tmp = 0;
          if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
94
95
96
            Log::Fatal("ignore_column is not a number,\n"
                       "if you want to use a column name,\n"
                       "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
97
98
          }
          ignore_features_.emplace(tmp);
Guolin Ke's avatar
Guolin Ke committed
99
100
101
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
102
    // load weight idx
Guolin Ke's avatar
Guolin Ke committed
103
104
105
    if (config_.weight_column.size() > 0) {
      if (Common::StartsWith(config_.weight_column, name_prefix)) {
        std::string name = config_.weight_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
106
107
108
109
110
111
        if (name2idx.count(name) > 0) {
          weight_idx_ = name2idx[name];
          Log::Info("Using column %s as weight", name.c_str());
        } else {
          Log::Fatal("Could not find weight column %s in data file", name.c_str());
        }
Guolin Ke's avatar
Guolin Ke committed
112
      } else {
Guolin Ke's avatar
Guolin Ke committed
113
        if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
114
115
116
          Log::Fatal("weight_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
117
118
        }
        Log::Info("Using column number %d as weight", weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
119
      }
Guolin Ke's avatar
Guolin Ke committed
120
      ignore_features_.emplace(weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
121
    }
Guolin Ke's avatar
Guolin Ke committed
122
    // load group idx
Guolin Ke's avatar
Guolin Ke committed
123
124
125
    if (config_.group_column.size() > 0) {
      if (Common::StartsWith(config_.group_column, name_prefix)) {
        std::string name = config_.group_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
126
127
128
129
130
131
132
        if (name2idx.count(name) > 0) {
          group_idx_ = name2idx[name];
          Log::Info("Using column %s as group/query id", name.c_str());
        } else {
          Log::Fatal("Could not find group/query column %s in data file", name.c_str());
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
133
        if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
134
135
136
          Log::Fatal("group_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
137
138
139
140
        }
        Log::Info("Using column number %d as group/query id", group_idx_);
      }
      ignore_features_.emplace(group_idx_);
Guolin Ke's avatar
Guolin Ke committed
141
142
    }
  }
Guolin Ke's avatar
Guolin Ke committed
143
144
145
  if (config_.categorical_feature.size() > 0) {
    if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
      std::string names = config_.categorical_feature.substr(name_prefix.size());
146
147
148
149
150
      for (auto name : Common::Split(names.c_str(), ',')) {
        if (name2idx.count(name) > 0) {
          int tmp = name2idx[name];
          categorical_features_.emplace(tmp);
        } else {
Guolin Ke's avatar
Guolin Ke committed
151
          Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
152
153
154
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
155
      for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
156
157
        int tmp = 0;
        if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Guolin Ke's avatar
Guolin Ke committed
158
          Log::Fatal("categorical_feature is not a number,\n"
159
160
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
161
162
163
164
165
        }
        categorical_features_.emplace(tmp);
      }
    }
  }
Guolin Ke's avatar
Guolin Ke committed
166
167
}

168
/*!
 * \brief Load a training dataset from a text file or from its binary counterpart.
 *
 * If CheckCanLoadFromBin returns a non-empty name the binary file is loaded.
 * Otherwise the text file is parsed, either fully in memory or in two passes over
 * the file when config_.two_round is set. Afterwards the metadata is checked or
 * partitioned and CheckDataset is run on the result.
 *
 * \param filename Path of the data file
 * \param rank Rank of this machine, forwarded to the loading helpers
 * \param num_machines Number of machines, forwarded to the loading helpers
 * \return Pointer released from the internal unique_ptr; the caller owns it
 */
Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_machines) {
  // don't support query id in data file when training in parallel
  if (num_machines > 1 && !config_.pre_partition) {
    // Index 0 is a valid group column (the first non-label column), so it must be
    // rejected as well. The NO_SPECIFIC comparison keeps the "not configured" case
    // out regardless of the sentinel's value.
    if (group_idx_ != NO_SPECIFIC && group_idx_ >= 0) {
      Log::Fatal("Using a query id without pre-partitioning the data file is not supported for parallel training.\n"
                 "Please use an additional query file or pre-partition the data");
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto bin_filename = CheckCanLoadFromBin(filename);
  if (bin_filename.size() == 0) {
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data to memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // sample data
      auto sample_data = SampleTextDataFromMemory(text_data);
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      // sample data from file
      auto sample_data = SampleTextDataFromFile(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      // an empty index list means no local subset was selected, so use the global count
      if (used_data_indices.size() > 0) {
        dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
      } else {
        dataset->num_data_ = num_global_data;
      }
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      Log::Debug("Making second pass...");
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
  }
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  // need to check training data
  CheckDataset(dataset.get());
  return dataset.release();
}



230
/*!
 * \brief Load a dataset that is aligned with an existing training dataset.
 *
 * Text input gets its feature layout through Dataset::CreateValid(train_data) instead
 * of constructing its own bin mappers. Loading is always done as a single machine
 * (rank 0 of 1). CheckDataset is not run on the result.
 *
 * \param filename Path of the data file
 * \param train_data Dataset passed to CreateValid for the text-file paths
 * \return Pointer released from the internal unique_ptr; the caller owns it
 */
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
  data_size_t total_rows = 0;
  std::vector<data_size_t> kept_rows;
  std::unique_ptr<Dataset> result(new Dataset());
  auto bin_path = CheckCanLoadFromBin(filename);

  if (bin_path.size() != 0) {
    // a binary file is available: load from it directly
    result.reset(LoadFromBinFile(filename, bin_path.c_str(), 0, 1, &total_rows, &kept_rows));
  } else {
    std::unique_ptr<Parser> parser(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (!parser) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    result->data_filename_ = filename;
    result->label_idx_ = label_idx_;
    result->metadata_.Init(filename);

    if (config_.two_round) {
      TextReader<data_size_t> text_reader(filename, config_.header);
      // the row count comes from counting the lines of the data file
      result->num_data_ = static_cast<data_size_t>(text_reader.CountLine());
      total_rows = result->num_data_;
      // set up label storage, then align with the training data
      result->metadata_.Init(result->num_data_, weight_idx_, group_idx_);
      result->CreateValid(train_data);
      // pull the features straight from the file
      ExtractFeaturesFromFile(filename, parser.get(), kept_rows, result.get());
    } else {
      // hold the whole text file in memory
      auto lines = LoadTextDataToMemory(filename, result->metadata_, 0, 1, &total_rows, &kept_rows);
      result->num_data_ = static_cast<data_size_t>(lines.size());
      // set up label storage, then align with the training data
      result->metadata_.Init(result->num_data_, weight_idx_, group_idx_);
      result->CreateValid(train_data);
      // pull the features from the in-memory lines
      ExtractFeaturesFromMemory(&lines, parser.get(), result.get());
      lines.clear();
    }
  }

  // validation data only gets the metadata check, not CheckDataset
  result->metadata_.CheckOrPartition(total_rows, kept_rows);
  return result.release();
}

274
275
276
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
                                        int rank, int num_machines, int* num_global_data,
                                        std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
277
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
278
  auto reader = VirtualFileReader::Make(bin_filename);
Guolin Ke's avatar
Guolin Ke committed
279
  dataset->data_filename_ = data_filename;
280
  if (!reader->Init()) {
Guolin Ke's avatar
Guolin Ke committed
281
282
283
284
285
    Log::Fatal("Could not read binary data from %s", bin_filename);
  }

  // buffer to read binary file
  size_t buffer_size = 16 * 1024 * 1024;
Guolin Ke's avatar
Guolin Ke committed
286
  auto buffer = std::vector<char>(buffer_size);
287

288
289
  // check token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
290
291
  size_t read_cnt = reader->Read(buffer.data(), sizeof(char) * size_of_token);
  if (read_cnt != sizeof(char) * size_of_token) {
292
293
294
    Log::Fatal("Binary file error: token has the wrong size");
  }
  if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
295
    Log::Fatal("Input file is not LightGBM binary file");
296
  }
Guolin Ke's avatar
Guolin Ke committed
297
298

  // read size of header
299
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
300

301
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
302
303
304
    Log::Fatal("Binary file error: header has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
305
  size_t size_of_head = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
306
307
308
309

  // re-allocmate space if not enough
  if (size_of_head > buffer_size) {
    buffer_size = size_of_head;
Guolin Ke's avatar
Guolin Ke committed
310
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
311
312
  }
  // read header
313
  read_cnt = reader->Read(buffer.data(), size_of_head);
Guolin Ke's avatar
Guolin Ke committed
314
315
316
317
318

  if (read_cnt != size_of_head) {
    Log::Fatal("Binary file error: header is incorrect");
  }
  // get header
Guolin Ke's avatar
Guolin Ke committed
319
  const char* mem_ptr = buffer.data();
Guolin Ke's avatar
Guolin Ke committed
320
321
322
323
324
325
  dataset->num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_data_);
  dataset->num_features_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_features_);
  dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
326
327
  dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->label_idx_);
328
329
330
331
332
333
334
335
336
337
  dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->max_bin_);
  dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->bin_construct_sample_cnt_);
  dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->min_data_in_bin_);
  dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += sizeof(dataset->use_missing_);
  dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += sizeof(dataset->zero_as_missing_);
Guolin Ke's avatar
Guolin Ke committed
338
339
  const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
  dataset->used_feature_map_.clear();
Guolin Ke's avatar
Guolin Ke committed
340
  for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
341
342
    dataset->used_feature_map_.push_back(tmp_feature_map[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
  mem_ptr += sizeof(int) * dataset->num_total_features_;
  // num_groups
  dataset->num_groups_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_groups_);
  // real_feature_idx_
  const int* tmp_ptr_real_feature_idx_ = reinterpret_cast<const int*>(mem_ptr);
  dataset->real_feature_idx_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->real_feature_idx_.push_back(tmp_ptr_real_feature_idx_[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // feature2group
  const int* tmp_ptr_feature2group = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2group_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2group_.push_back(tmp_ptr_feature2group[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // feature2subfeature
  const int* tmp_ptr_feature2subfeature = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2subfeature_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2subfeature_.push_back(tmp_ptr_feature2subfeature[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // group_bin_boundaries
  const uint64_t* tmp_ptr_group_bin_boundaries = reinterpret_cast<const uint64_t*>(mem_ptr);
  dataset->group_bin_boundaries_.clear();
  for (int i = 0; i < dataset->num_groups_ + 1; ++i) {
    dataset->group_bin_boundaries_.push_back(tmp_ptr_group_bin_boundaries[i]);
  }
  mem_ptr += sizeof(uint64_t) * (dataset->num_groups_ + 1);

  // group_feature_start_
  const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_start_.clear();
379
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
380
381
382
383
384
385
386
387
388
389
390
391
    dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
  }
  mem_ptr += sizeof(int) * (dataset->num_groups_);

  // group_feature_cnt_
  const int* tmp_ptr_group_feature_cnt = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_cnt_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_cnt_.push_back(tmp_ptr_group_feature_cnt[i]);
  }
  mem_ptr += sizeof(int) * (dataset->num_groups_);

Belinda Trotta's avatar
Belinda Trotta committed
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
  if (!config_.max_bin_by_feature.empty()) {
    CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.max_bin_by_feature.size());
    CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
  } else {
    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
    dataset->max_bin_by_feature_.clear();
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
    }
  }
  mem_ptr += sizeof(int32_t) * (dataset->num_total_features_);
  if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
    dataset->max_bin_by_feature_.clear();
  }

Guolin Ke's avatar
Guolin Ke committed
409
  // get feature names
Guolin Ke's avatar
Guolin Ke committed
410
  dataset->feature_names_.clear();
Guolin Ke's avatar
Guolin Ke committed
411
412
413
414
415
416
417
418
419
420
  // write feature names
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int str_len = *(reinterpret_cast<const int*>(mem_ptr));
    mem_ptr += sizeof(int);
    std::stringstream str_buf;
    for (int j = 0; j < str_len; ++j) {
      char tmp_char = *(reinterpret_cast<const char*>(mem_ptr));
      mem_ptr += sizeof(char);
      str_buf << tmp_char;
    }
Guolin Ke's avatar
Guolin Ke committed
421
    dataset->feature_names_.emplace_back(str_buf.str());
Guolin Ke's avatar
Guolin Ke committed
422
  }
423
424
425
426
427
428
429
430
431
432
433
434
435
  // get forced_bin_bounds_
  dataset->forced_bin_bounds_ = std::vector<std::vector<double>>(dataset->num_total_features_, std::vector<double>());
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int num_bounds = *(reinterpret_cast<const int*>(mem_ptr));
    mem_ptr += sizeof(int);
    dataset->forced_bin_bounds_[i] = std::vector<double>();
    const double* tmp_ptr_forced_bounds = reinterpret_cast<const double*>(mem_ptr);
    for (int j = 0; j < num_bounds; ++j) {
      double bound = tmp_ptr_forced_bounds[j];
      dataset->forced_bin_bounds_[i].push_back(bound);
    }
    mem_ptr += num_bounds * sizeof(double);
  }
Guolin Ke's avatar
Guolin Ke committed
436
437

  // read size of meta data
438
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
439

440
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
441
442
443
    Log::Fatal("Binary file error: meta data has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
444
  size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
445
446
447
448

  // re-allocate space if not enough
  if (size_of_metadata > buffer_size) {
    buffer_size = size_of_metadata;
Guolin Ke's avatar
Guolin Ke committed
449
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
450
451
  }
  //  read meta data
452
  read_cnt = reader->Read(buffer.data(), size_of_metadata);
Guolin Ke's avatar
Guolin Ke committed
453
454
455
456
457

  if (read_cnt != size_of_metadata) {
    Log::Fatal("Binary file error: meta data is incorrect");
  }
  // load meta data
Guolin Ke's avatar
Guolin Ke committed
458
  dataset->metadata_.LoadFromMemory(buffer.data());
Guolin Ke's avatar
Guolin Ke committed
459

460
461
  *num_global_data = dataset->num_data_;
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
462
  // sample local used data if need to partition
Guolin Ke's avatar
Guolin Ke committed
463
  if (num_machines > 1 && !config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
464
465
466
467
    const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
468
        if (random_.NextShort(0, num_machines) == rank) {
469
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
470
471
472
473
474
475
476
477
478
        }
      }
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = dataset->metadata_.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (qid >= num_queries) {
479
480
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
481
482
483
484
        }
        if (i >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
485
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
486
487
488
489
490
            is_query_used = true;
          }
          ++qid;
        }
        if (is_query_used) {
491
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
492
493
494
        }
      }
    }
495
    dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
Guolin Ke's avatar
Guolin Ke committed
496
  }
497
  dataset->metadata_.PartitionLabel(*used_data_indices);
Guolin Ke's avatar
Guolin Ke committed
498
  // read feature data
Guolin Ke's avatar
Guolin Ke committed
499
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
500
    // read feature size
501
502
    read_cnt = reader->Read(buffer.data(), sizeof(size_t));
    if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
503
504
      Log::Fatal("Binary file error: feature %d has the wrong size", i);
    }
Guolin Ke's avatar
Guolin Ke committed
505
    size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
506
507
508
    // re-allocate space if not enough
    if (size_of_feature > buffer_size) {
      buffer_size = size_of_feature;
Guolin Ke's avatar
Guolin Ke committed
509
      buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
510
511
    }

512
    read_cnt = reader->Read(buffer.data(), size_of_feature);
Guolin Ke's avatar
Guolin Ke committed
513
514
515
516

    if (read_cnt != size_of_feature) {
      Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
    }
Guolin Ke's avatar
Guolin Ke committed
517
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
518
519
      new FeatureGroup(buffer.data(),
                       *num_global_data,
520
                       *used_data_indices)));
Guolin Ke's avatar
Guolin Ke committed
521
  }
Guolin Ke's avatar
Guolin Ke committed
522
  dataset->feature_groups_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
523
  dataset->is_finish_load_ = true;
Guolin Ke's avatar
Guolin Ke committed
524
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
525
526
}

527

528
529
530
Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
                                               int** sample_indices, int num_col, const int* num_per_col,
                                               size_t total_sample_size, data_size_t num_data) {
531
532
533
534
535
  int num_total_features = num_col;
  if (Network::num_machines() > 1) {
    num_total_features = Network::GlobalSyncUpByMax(num_total_features);
  }
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_total_features);
536
537
  // fill feature_names_ if not header
  if (feature_names_.empty()) {
538
    for (int i = 0; i < num_col; ++i) {
539
540
541
542
543
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
Belinda Trotta's avatar
Belinda Trotta committed
544
545
546
547
  if (!config_.max_bin_by_feature.empty()) {
    CHECK(static_cast<size_t>(num_col) == config_.max_bin_by_feature.size());
    CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
  }
548
549
550
551
552

  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
553
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
554
    static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
555
556
557
558
559
560
561
562
563
564
565
566
567
  if (Network::num_machines() == 1) {
    // if only one machine, find bin locally
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
568
569
570
571
        bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
        if (!feat_is_unconstrained) {
            Log::Fatal("The output cannot be monotone with respect to categorical features");
        }
572
573
      }
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
574
575
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
576
                                config_.max_bin, config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter,
577
578
                                bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
579
580
      } else {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
581
                                config_.max_bin_by_feature[i], config_.min_data_in_bin,
582
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
583
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
584
      }
585
586
587
588
589
590
591
592
593
594
595
596
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    // if have multi-machines, need to find bin distributed
    // different machines will find bin for different features
    int num_machines = Network::num_machines();
    int rank = Network::rank();
    // start and len will store the process feature indices for different machines
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
597
    int step = (num_total_features + num_machines - 1) / num_machines;
598
599
600
601
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
602
      len[i] = std::min(step, num_total_features - start[i]);
603
604
      start[i + 1] = start[i] + len[i];
    }
605
    len[num_machines - 1] = num_total_features - start[num_machines - 1];
606
607
608
609
610
611
612
613
614
615
616
617
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < len[rank]; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
618
619
620
      if (num_col <= start[rank] + i) {
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
621
      if (config_.max_bin_by_feature.empty()) {
622
623
        bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
                                total_sample_size, config_.max_bin, config_.min_data_in_bin,
624
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
625
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
626
      } else {
627
628
        bin_mappers[i]->FindBin(sample_values[start[rank] + i], num_per_col[start[rank] + i],
                                total_sample_size, config_.max_bin_by_feature[start[rank] + i],
629
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
630
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
631
      }
632
633
      OMP_LOOP_EX_END();
    }
634
    comm_size_t self_buf_size = 0;
635
    for (int i = 0; i < len[rank]; ++i) {
636
637
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
638
      }
639
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
640
    }
641
642
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
643
644
645
646
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
647
648
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
649
650
651
      // free
      bin_mappers[i].reset(nullptr);
    }
652
653
654
655
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
656
    }
657
658
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
659
    // gather global feature bin mappers
660
661
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
662
    // restore features bins from buffer
663
    for (int i = 0; i < num_total_features; ++i) {
664
665
666
667
668
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      bin_mappers[i].reset(new BinMapper());
669
670
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
671
    }
Guolin Ke's avatar
Guolin Ke committed
672
  }
Guolin Ke's avatar
Guolin Ke committed
673
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
Guolin Ke's avatar
Guolin Ke committed
674
  dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
675
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
676
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
677
}
Guolin Ke's avatar
Guolin Ke committed
678
679
680
681
682
683


// ---- private functions ----

void DatasetLoader::CheckDataset(const Dataset* dataset) {
  if (dataset->num_data_ <= 0) {
Guolin Ke's avatar
Guolin Ke committed
684
    Log::Fatal("Data file %s is empty", dataset->data_filename_.c_str());
Guolin Ke's avatar
Guolin Ke committed
685
  }
686
687
  if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
    Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
688
               static_cast<int>(dataset->feature_names_.size()));
689
  }
Guolin Ke's avatar
Guolin Ke committed
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
  bool is_feature_order_by_group = true;
  int last_group = -1;
  int last_sub_feature = -1;
  // if features are ordered, not need to use hist_buf
  for (int i = 0; i < dataset->num_features_; ++i) {
    int group = dataset->feature2group_[i];
    int sub_feature = dataset->feature2subfeature_[i];
    if (group < last_group) {
      is_feature_order_by_group = false;
    } else if (group == last_group) {
      if (sub_feature <= last_sub_feature) {
        is_feature_order_by_group = false;
        break;
      }
    }
    last_group = group;
    last_sub_feature = sub_feature;
  }
  if (!is_feature_order_by_group) {
709
    Log::Fatal("Features in dataset should be ordered by group");
Guolin Ke's avatar
Guolin Ke committed
710
  }
Guolin Ke's avatar
Guolin Ke committed
711
712
713
}

std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
714
715
                                                             int rank, int num_machines, int* num_global_data,
                                                             std::vector<data_size_t>* used_data_indices) {
716
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
717
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
718
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
719
720
721
722
723
724
725
726
727
    // read all lines
    *num_global_data = text_reader.ReadAllLines();
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();

    if (query_boundaries == nullptr) {
      // if not contain query data, minimal sample unit is one record
      *num_global_data = text_reader.ReadAndFilterLines([this, rank, num_machines](data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
728
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
729
730
731
732
733
734
735
736
737
738
739
740
741
742
          return true;
        } else {
          return false;
        }
      }, used_data_indices);
    } else {
      // if contain query data, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.ReadAndFilterLines(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
743
744
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
745
746
747
748
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
749
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
750
751
752
753
754
755
756
757
758
759
760
761
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices);
    }
  }
  return std::move(text_reader.Lines());
}

std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
Guolin Ke's avatar
Guolin Ke committed
762
  int sample_cnt = config_.bin_construct_sample_cnt;
763
764
  if (static_cast<size_t>(sample_cnt) > data.size()) {
    sample_cnt = static_cast<int>(data.size());
765
  }
766
  auto sample_indices = random_.Sample(static_cast<int>(data.size()), sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
767
  std::vector<std::string> out(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
768
769
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    const size_t idx = sample_indices[i];
Guolin Ke's avatar
Guolin Ke committed
770
    out[i] = data[idx];
Guolin Ke's avatar
Guolin Ke committed
771
772
773
774
  }
  return out;
}

775
776
777
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata,
                                                               int rank, int num_machines, int* num_global_data,
                                                               std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
778
  const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
779
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
780
  std::vector<std::string> out_data;
Guolin Ke's avatar
Guolin Ke committed
781
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
782
    *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(&random_, sample_cnt, &out_data));
Guolin Ke's avatar
Guolin Ke committed
783
784
785
786
787
788
789
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      *num_global_data = text_reader.SampleAndFilterFromFile([this, rank, num_machines]
      (data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
790
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
791
792
793
794
          return true;
        } else {
          return false;
        }
Guolin Ke's avatar
Guolin Ke committed
795
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
796
797
798
799
800
801
802
803
804
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.SampleAndFilterFromFile(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
805
806
          Log::Fatal("Query id exceeds the range of the query file, "
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
807
808
809
810
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
811
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
812
813
814
815
816
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
Guolin Ke's avatar
Guolin Ke committed
817
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
818
819
820
821
822
    }
  }
  return out_data;
}

823
824
825
void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
                                                    const std::vector<std::string>& sample_data,
                                                    const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
826
  std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
827
  std::vector<std::vector<int>> sample_indices;
Guolin Ke's avatar
Guolin Ke committed
828
829
  std::vector<std::pair<int, double>> oneline_features;
  double label;
Guolin Ke's avatar
Guolin Ke committed
830
  for (int i = 0; i < static_cast<int>(sample_data.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
831
832
833
834
    oneline_features.clear();
    // parse features
    parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
    for (std::pair<int, double>& inner_data : oneline_features) {
835
      if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
836
837
        sample_values.resize(inner_data.first + 1);
        sample_indices.resize(inner_data.first + 1);
838
      }
Guolin Ke's avatar
Guolin Ke committed
839
      if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
840
841
        sample_values[inner_data.first].emplace_back(inner_data.second);
        sample_indices[inner_data.first].emplace_back(i);
Guolin Ke's avatar
Guolin Ke committed
842
843
844
845
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
846
  dataset->feature_groups_.clear();
847
848
849
850
851
852
  dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
  if (num_machines > 1) {
    dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_);
  }
  if (!feature_names_.empty()) {
    CHECK(dataset->num_total_features_ == static_cast<int>(feature_names_.size()));
853
  }
Guolin Ke's avatar
Guolin Ke committed
854

Belinda Trotta's avatar
Belinda Trotta committed
855
856
857
858
859
  if (!config_.max_bin_by_feature.empty()) {
    CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.max_bin_by_feature.size());
    CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
  }

860
861
  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
862
863
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path,
                                                                                    dataset->num_total_features_,
864
865
                                                                                    categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
866
867
868
869
870
871
  // check the range of label_idx, weight_idx and group_idx
  CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
  CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
  CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);

  // fill feature_names_ if not header
Guolin Ke's avatar
Guolin Ke committed
872
  if (feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
873
874
875
876
877
878
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
879
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
880
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
881
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
882
    static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
Guolin Ke's avatar
Guolin Ke committed
883
884
885
  // start find bins
  if (num_machines == 1) {
    // if only one machine, find bin locally
886
    OMP_INIT_EX();
887
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
888
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
889
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
890
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
891
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
892
893
        continue;
      }
894
895
896
897
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
Guolin Ke's avatar
Guolin Ke committed
898
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
899
900
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
901
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
902
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
903
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
904
905
      } else {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
906
                                sample_data.size(), config_.max_bin_by_feature[i],
907
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
908
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
909
      }
910
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
911
    }
912
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
913
914
  } else {
    // start and len will store the process feature indices for different machines
915
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
Guolin Ke's avatar
Guolin Ke committed
916
917
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
918
    int step = (dataset->num_total_features_ + num_machines - 1) / num_machines;
Guolin Ke's avatar
Guolin Ke committed
919
920
921
922
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
923
      len[i] = std::min(step, dataset->num_total_features_ - start[i]);
Guolin Ke's avatar
Guolin Ke committed
924
925
      start[i + 1] = start[i] + len[i];
    }
926
    len[num_machines - 1] = dataset->num_total_features_ - start[num_machines - 1];
927
    OMP_INIT_EX();
928
    #pragma omp parallel for schedule(guided)
929
    for (int i = 0; i < len[rank]; ++i) {
930
      OMP_LOOP_EX_BEGIN();
931
932
933
934
935
936
937
938
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
Nikita Titov's avatar
Nikita Titov committed
939
      if (static_cast<int>(sample_values.size()) <= start[rank] + i) {
940
941
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
942
      if (config_.max_bin_by_feature.empty()) {
943
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
944
                                static_cast<int>(sample_values[start[rank] + i].size()),
945
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
946
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
947
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
948
      } else {
949
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
950
                                static_cast<int>(sample_values[start[rank] + i].size()),
951
                                sample_data.size(), config_.max_bin_by_feature[i],
952
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type,
953
                                config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
954
      }
955
      OMP_LOOP_EX_END();
956
    }
957
    OMP_THROW_EX();
958
    comm_size_t self_buf_size = 0;
Guolin Ke's avatar
Guolin Ke committed
959
    for (int i = 0; i < len[rank]; ++i) {
960
961
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
Guolin Ke's avatar
Guolin Ke committed
962
      }
963
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
964
    }
965
966
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
967
    for (int i = 0; i < len[rank]; ++i) {
Guolin Ke's avatar
Guolin Ke committed
968
969
970
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
971
972
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
973
974
      // free
      bin_mappers[i].reset(nullptr);
Guolin Ke's avatar
Guolin Ke committed
975
    }
976
977
978
979
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
Guolin Ke's avatar
Guolin Ke committed
980
    }
981
982
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
Guolin Ke's avatar
Guolin Ke committed
983
    // gather global feature bin mappers
984
985
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
986
    // restore features bins from buffer
987
    for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
988
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
989
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
990
991
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
992
      bin_mappers[i].reset(new BinMapper());
993
994
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
Guolin Ke's avatar
Guolin Ke committed
995
996
    }
  }
997
  dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Guolin Ke's avatar
Guolin Ke committed
998
                     Common::Vector2Ptr<double>(&sample_values).data(),
999
                     Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
Guolin Ke's avatar
Guolin Ke committed
1000
1001
1002
}

/*! \brief Extract local features from memory */
Guolin Ke's avatar
Guolin Ke committed
1003
void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
1004
1005
  std::vector<std::pair<int, double>> oneline_features;
  double tmp_label = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
1006
  auto& ref_text_data = *text_data;
Guolin Ke's avatar
Guolin Ke committed
1007
  if (predict_fun_ == nullptr) {
1008
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1009
    // if doesn't need to prediction with initial model
1010
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1011
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1012
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1013
1014
1015
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1016
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1017
      // set label
1018
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1019
      // free processed line:
Guolin Ke's avatar
Guolin Ke committed
1020
      ref_text_data[i].clear();
Guolin Ke's avatar
Guolin Ke committed
1021
1022
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
1023
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1024
1025
      // push data
      for (auto& inner_data : oneline_features) {
1026
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1027
1028
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1029
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1030
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1031
1032
1033
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1034
1035
        } else {
          if (inner_data.first == weight_idx_) {
1036
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1037
1038
1039
1040
1041
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1042
      dataset->FinishOneRow(tid, i, is_feature_added);
1043
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1044
    }
1045
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1046
  } else {
1047
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1048
    // if need to prediction with initial model
1049
    std::vector<double> init_score(dataset->num_data_ * num_class_);
1050
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1051
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1052
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1053
1054
1055
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1056
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1057
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1058
1059
      std::vector<double> oneline_init_score(num_class_);
      predict_fun_(oneline_features, oneline_init_score.data());
1060
      for (int k = 0; k < num_class_; ++k) {
1061
        init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1062
1063
      }
      // set label
1064
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1065
1066
1067
1068
1069
      // free processed line:
      text_data[i].clear();
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
Guolin Ke's avatar
Guolin Ke committed
1070
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1071
      for (auto& inner_data : oneline_features) {
1072
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1073
1074
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1075
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1076
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1077
1078
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
1079
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1080
1081
        } else {
          if (inner_data.first == weight_idx_) {
1082
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1083
1084
1085
1086
1087
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1088
      dataset->FinishOneRow(tid, i, is_feature_added);
1089
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1090
    }
1091
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1092
    // metadata_ will manage space of init_score
1093
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1094
  }
Guolin Ke's avatar
Guolin Ke committed
1095
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1096
  // text data can be free after loaded feature values
Guolin Ke's avatar
Guolin Ke committed
1097
  text_data->clear();
Guolin Ke's avatar
Guolin Ke committed
1098
1099
1100
}

/*! \brief Extract local features from file */
1101
1102
void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser,
                                            const std::vector<data_size_t>& used_data_indices, Dataset* dataset) {
1103
  std::vector<double> init_score;
Guolin Ke's avatar
Guolin Ke committed
1104
  if (predict_fun_ != nullptr) {
1105
    init_score = std::vector<double>(dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1106
1107
1108
1109
1110
1111
  }
  std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
    [this, &init_score, &parser, &dataset]
  (data_size_t start_idx, const std::vector<std::string>& lines) {
    std::vector<std::pair<int, double>> oneline_features;
    double tmp_label = 0.0f;
1112
    OMP_INIT_EX();
1113
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1114
    for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
1115
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1116
1117
1118
1119
1120
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1121
      if (!init_score.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1122
1123
        std::vector<double> oneline_init_score(num_class_);
        predict_fun_(oneline_features, oneline_init_score.data());
1124
        for (int k = 0; k < num_class_; ++k) {
1125
          init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1126
1127
1128
        }
      }
      // set label
1129
      dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1130
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1131
1132
      // push data
      for (auto& inner_data : oneline_features) {
1133
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1134
1135
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1136
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1137
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1138
1139
1140
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1141
1142
        } else {
          if (inner_data.first == weight_idx_) {
1143
            dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1144
1145
1146
1147
1148
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1149
      dataset->FinishOneRow(tid, i, is_feature_added);
1150
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1151
    }
1152
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1153
  };
1154
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
1155
  if (!used_data_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1156
1157
1158
1159
1160
1161
1162
1163
    // only need part of data
    text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
  } else {
    // need full data
    text_reader.ReadAllAndProcessParallel(process_fun);
  }

  // metadata_ will manage space of init_score
Guolin Ke's avatar
Guolin Ke committed
1164
  if (!init_score.empty()) {
1165
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1166
  }
Guolin Ke's avatar
Guolin Ke committed
1167
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1168
1169
1170
}

/*! \brief Check can load from binary file */
1171
std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
1172
1173
1174
  std::string bin_filename(filename);
  bin_filename.append(".bin");

1175
  auto reader = VirtualFileReader::Make(bin_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
1176

1177
  if (!reader->Init()) {
1178
    bin_filename = std::string(filename);
1179
1180
    reader = VirtualFileReader::Make(bin_filename.c_str());
    if (!reader->Init()) {
1181
      Log::Fatal("Cannot open data file %s", bin_filename.c_str());
1182
    }
1183
  }
1184
1185
1186
1187
1188

  size_t buffer_size = 256;
  auto buffer = std::vector<char>(buffer_size);
  // read size of token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
1189
  size_t read_cnt = reader->Read(buffer.data(), size_of_token);
1190
1191
  if (read_cnt == size_of_token
      && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
1192
    return bin_filename;
Guolin Ke's avatar
Guolin Ke committed
1193
  } else {
1194
    return std::string();
Guolin Ke's avatar
Guolin Ke committed
1195
1196
1197
  }
}

1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235


/*!
 * \brief Load forced bin upper bounds from a JSON file
 *
 * The file is read as a JSON array; for each element the "feature" entry gives the feature
 * index and the "bin_upper_bound" entry gives the array of bounds for that feature.
 * Entries for categorical features are ignored with a warning.
 *
 * \param forced_bins_path Path of the JSON file; an empty string means no forced bins
 * \param num_total_features Number of features; also the size of the returned outer vector
 * \param categorical_features Indices of categorical features
 * \return One vector of bounds per feature (empty when nothing is forced for that feature).
 *         A file that cannot be opened only produces a warning and yields all-empty vectors.
 */
std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
                                                              const std::unordered_set<int>& categorical_features) {
  std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
  if (forced_bins_path != "") {
    std::ifstream forced_bins_stream(forced_bins_path.c_str());
    if (forced_bins_stream.fail()) {
      Log::Warning("Could not open %s. Will ignore.", forced_bins_path.c_str());
    } else {
      std::stringstream buffer;
      buffer << forced_bins_stream.rdbuf();
      std::string err;
      Json forced_bins_json = Json::parse(buffer.str(), err);
      // report the parser's own message instead of only failing the array check below
      if (!err.empty()) {
        Log::Fatal("Cannot parse forced bins file %s: %s", forced_bins_path.c_str(), err.c_str());
      }
      CHECK(forced_bins_json.is_array());
      const auto& forced_bins_arr = forced_bins_json.array_items();
      for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
        int feature_num = forced_bins_arr[i]["feature"].int_value();
        // a negative index would write outside forced_bins
        CHECK(feature_num >= 0);
        CHECK(feature_num < num_total_features);
        if (categorical_features.count(feature_num)) {
          Log::Warning("Feature %d is categorical. Will ignore forced bins for this  feature.", feature_num);
        } else {
          const auto& bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items();
          for (size_t j = 0; j < bounds_arr.size(); ++j) {
            forced_bins[feature_num].push_back(bounds_arr[j].number_value());
          }
        }
      }
      // remove duplicates
      // NOTE: std::unique only collapses adjacent equal values, so this relies on equal bounds being neighbours
      for (int i = 0; i < num_total_features; ++i) {
        auto new_end = std::unique(forced_bins[i].begin(), forced_bins[i].end());
        forced_bins[i].erase(new_end, forced_bins[i].end());
      }
    }
  }
  return forced_bins;
}

1236
}  // namespace LightGBM