dataset_loader.cpp 52.9 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/dataset_loader.h>

Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/network.h>
8
9
10
#include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
11

12
13
#include <fstream>

14
15
#include <LightGBM/json11.hpp>

16
17
using namespace json11;

Guolin Ke's avatar
Guolin Ke committed
18
19
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
20
21
/*!
 * \brief Create a loader bound to a configuration and resolve column settings.
 * \param io_config Configuration stored in config_; its data_random_seed seeds random_
 * \param predict_fun Prediction callback stored in predict_fun_
 * \param num_class Number of classes stored in num_class_
 * \param filename Data file passed to SetHeader(); may be nullptr
 */
DatasetLoader::DatasetLoader(const Config& io_config,
                             const PredictFunction& predict_fun,
                             int num_class,
                             const char* filename)
    : config_(io_config),
      random_(config_.data_random_seed),
      predict_fun_(predict_fun),
      num_class_(num_class) {
  // Defaults in effect until SetHeader() applies the configured columns:
  // no weight column, no group column, label in column 0.
  group_idx_ = NO_SPECIFIC;
  weight_idx_ = NO_SPECIFIC;
  label_idx_ = 0;
  SetHeader(filename);
}

/*! \brief Destructor; performs no explicit cleanup of its own. */
DatasetLoader::~DatasetLoader() = default;

Guolin Ke's avatar
Guolin Ke committed
31
void DatasetLoader::SetHeader(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
32
  std::unordered_map<std::string, int> name2idx;
Guolin Ke's avatar
Guolin Ke committed
33
34
  std::string name_prefix("name:");
  if (filename != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
35
    TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
36

Guolin Ke's avatar
Guolin Ke committed
37
    // get column names
Guolin Ke's avatar
Guolin Ke committed
38
    if (config_.header) {
Guolin Ke's avatar
Guolin Ke committed
39
      std::string first_line = text_reader.first_line();
40
      feature_names_ = Common::Split(first_line.c_str(), "\t,");
Guolin Ke's avatar
Guolin Ke committed
41
42
    }

Guolin Ke's avatar
Guolin Ke committed
43
    // load label idx first
Guolin Ke's avatar
Guolin Ke committed
44
45
46
    if (config_.label_column.size() > 0) {
      if (Common::StartsWith(config_.label_column, name_prefix)) {
        std::string name = config_.label_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
47
48
49
50
51
52
53
54
55
56
        label_idx_ = -1;
        for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
          if (name == feature_names_[i]) {
            label_idx_ = i;
            break;
          }
        }
        if (label_idx_ >= 0) {
          Log::Info("Using column %s as label", name.c_str());
        } else {
57
58
          Log::Fatal("Could not find label column %s in data file \n"
                     "or data file doesn't contain header", name.c_str());
Guolin Ke's avatar
Guolin Ke committed
59
        }
Guolin Ke's avatar
Guolin Ke committed
60
      } else {
Guolin Ke's avatar
Guolin Ke committed
61
        if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
62
63
64
          Log::Fatal("label_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
65
66
        }
        Log::Info("Using column number %d as label", label_idx_);
Guolin Ke's avatar
Guolin Ke committed
67
68
      }
    }
Guolin Ke's avatar
Guolin Ke committed
69

Guolin Ke's avatar
Guolin Ke committed
70
    if (!feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
71
72
73
74
      // erase label column name
      feature_names_.erase(feature_names_.begin() + label_idx_);
      for (size_t i = 0; i < feature_names_.size(); ++i) {
        name2idx[feature_names_[i]] = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
75
      }
Guolin Ke's avatar
Guolin Ke committed
76
77
78
    }

    // load ignore columns
Guolin Ke's avatar
Guolin Ke committed
79
80
81
    if (config_.ignore_column.size() > 0) {
      if (Common::StartsWith(config_.ignore_column, name_prefix)) {
        std::string names = config_.ignore_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
82
83
84
85
86
87
88
89
90
        for (auto name : Common::Split(names.c_str(), ',')) {
          if (name2idx.count(name) > 0) {
            int tmp = name2idx[name];
            ignore_features_.emplace(tmp);
          } else {
            Log::Fatal("Could not find ignore column %s in data file", name.c_str());
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
91
        for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
Guolin Ke's avatar
Guolin Ke committed
92
93
          int tmp = 0;
          if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
94
95
96
            Log::Fatal("ignore_column is not a number,\n"
                       "if you want to use a column name,\n"
                       "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
97
98
          }
          ignore_features_.emplace(tmp);
Guolin Ke's avatar
Guolin Ke committed
99
100
101
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
102
    // load weight idx
Guolin Ke's avatar
Guolin Ke committed
103
104
105
    if (config_.weight_column.size() > 0) {
      if (Common::StartsWith(config_.weight_column, name_prefix)) {
        std::string name = config_.weight_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
106
107
108
109
110
111
        if (name2idx.count(name) > 0) {
          weight_idx_ = name2idx[name];
          Log::Info("Using column %s as weight", name.c_str());
        } else {
          Log::Fatal("Could not find weight column %s in data file", name.c_str());
        }
Guolin Ke's avatar
Guolin Ke committed
112
      } else {
Guolin Ke's avatar
Guolin Ke committed
113
        if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
114
115
116
          Log::Fatal("weight_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
117
118
        }
        Log::Info("Using column number %d as weight", weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
119
      }
Guolin Ke's avatar
Guolin Ke committed
120
      ignore_features_.emplace(weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
121
    }
Guolin Ke's avatar
Guolin Ke committed
122
    // load group idx
Guolin Ke's avatar
Guolin Ke committed
123
124
125
    if (config_.group_column.size() > 0) {
      if (Common::StartsWith(config_.group_column, name_prefix)) {
        std::string name = config_.group_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
126
127
128
129
130
131
132
        if (name2idx.count(name) > 0) {
          group_idx_ = name2idx[name];
          Log::Info("Using column %s as group/query id", name.c_str());
        } else {
          Log::Fatal("Could not find group/query column %s in data file", name.c_str());
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
133
        if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
134
135
136
          Log::Fatal("group_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
137
138
139
140
        }
        Log::Info("Using column number %d as group/query id", group_idx_);
      }
      ignore_features_.emplace(group_idx_);
Guolin Ke's avatar
Guolin Ke committed
141
142
    }
  }
Guolin Ke's avatar
Guolin Ke committed
143
144
145
  if (config_.categorical_feature.size() > 0) {
    if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
      std::string names = config_.categorical_feature.substr(name_prefix.size());
146
147
148
149
150
      for (auto name : Common::Split(names.c_str(), ',')) {
        if (name2idx.count(name) > 0) {
          int tmp = name2idx[name];
          categorical_features_.emplace(tmp);
        } else {
Guolin Ke's avatar
Guolin Ke committed
151
          Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
152
153
154
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
155
      for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
156
157
        int tmp = 0;
        if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Guolin Ke's avatar
Guolin Ke committed
158
          Log::Fatal("categorical_feature is not a number,\n"
159
160
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
161
162
163
164
165
        }
        categorical_features_.emplace(tmp);
      }
    }
  }
Guolin Ke's avatar
Guolin Ke committed
166
167
}

168
/*!
 * \brief Load a training dataset from a text file or its binary counterpart.
 *
 * If CheckCanLoadFromBin() returns a non-empty file name the dataset is read
 * through LoadFromBinFile(). Otherwise the text file is parsed: bin mappers
 * are built from a sample of the rows and the features are then extracted,
 * either from memory or, when config_.two_round is set, by a second pass
 * over the file.
 *
 * \param filename Path of the data file
 * \param initscore_file Init-score file handed to the metadata initialisation
 * \param rank Rank of this machine, forwarded to the loading helpers
 * \param num_machines Number of machines taking part in training
 * \return Newly allocated Dataset; ownership is transferred to the caller
 */
Dataset* DatasetLoader::LoadFromFile(const char* filename, const char* initscore_file, int rank, int num_machines) {
  // don't support query id in data file when training in parallel
  if (num_machines > 1 && !config_.pre_partition) {
    // Column index 0 is a valid group/query column, so the test must include
    // it; the previous "> 0" comparison let that configuration slip through.
    if (group_idx_ != NO_SPECIFIC && group_idx_ >= 0) {
      Log::Fatal("Using a query id without pre-partitioning the data file is not supported for parallel training.\n"
                 "Please use an additional query file or pre-partition the data");
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto bin_filename = CheckCanLoadFromBin(filename);
  if (bin_filename.size() == 0) {
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename, initscore_file);
    if (!config_.two_round) {
      // read data to memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // sample data
      auto sample_data = SampleTextDataFromMemory(text_data);
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      // sample data from file
      auto sample_data = SampleTextDataFromFile(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      // a non-empty index list means only a subset of the rows is used locally
      if (used_data_indices.size() > 0) {
        dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
      } else {
        dataset->num_data_ = num_global_data;
      }
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
  }
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  // need to check training data
  CheckDataset(dataset.get());
  return dataset.release();
}



229
/*!
 * \brief Load a dataset from file and align it with an existing dataset.
 *
 * Uses the binary file when CheckCanLoadFromBin() reports one; otherwise the
 * text file is parsed and CreateValid(train_data) is called on the new
 * dataset before its features are extracted. Loading always uses rank 0 of
 * 1 machine, and CheckDataset() is not run on the result.
 *
 * \param filename Path of the data file
 * \param initscore_file Init-score file handed to the metadata initialisation
 * \param train_data Dataset passed to CreateValid()
 * \return Newly allocated Dataset; ownership is transferred to the caller
 */
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const char* initscore_file, const Dataset* train_data) {
  data_size_t total_rows = 0;
  std::vector<data_size_t> row_indices;
  std::unique_ptr<Dataset> result(new Dataset());
  auto binary_path = CheckCanLoadFromBin(filename);

  if (binary_path.size() > 0) {
    // a usable binary file exists: read everything from it
    result.reset(LoadFromBinFile(filename, binary_path.c_str(), 0, 1, &total_rows, &row_indices));
  } else {
    std::unique_ptr<Parser> parser(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (!parser) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    result->data_filename_ = filename;
    result->label_idx_ = label_idx_;
    result->metadata_.Init(filename, initscore_file);

    if (config_.two_round) {
      // two-round mode: count the rows first, then stream features from the file
      TextReader<data_size_t> text_reader(filename, config_.header);
      result->num_data_ = static_cast<data_size_t>(text_reader.CountLine());
      total_rows = result->num_data_;
      // set up label storage
      result->metadata_.Init(result->num_data_, weight_idx_, group_idx_);
      result->CreateValid(train_data);
      // pull features straight from the file
      ExtractFeaturesFromFile(filename, parser.get(), row_indices, result.get());
    } else {
      // one-round mode: hold all text lines in memory
      auto lines = LoadTextDataToMemory(filename, result->metadata_, 0, 1, &total_rows, &row_indices);
      result->num_data_ = static_cast<data_size_t>(lines.size());
      // set up label storage
      result->metadata_.Init(result->num_data_, weight_idx_, group_idx_);
      result->CreateValid(train_data);
      // pull features from the in-memory lines
      ExtractFeaturesFromMemory(&lines, parser.get(), result.get());
      lines.clear();
    }
  }

  // validation data is not passed through CheckDataset(); only metadata is checked
  result->metadata_.CheckOrPartition(total_rows, row_indices);
  return result.release();
}

273
274
275
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
                                        int rank, int num_machines, int* num_global_data,
                                        std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
276
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
277
  auto reader = VirtualFileReader::Make(bin_filename);
Guolin Ke's avatar
Guolin Ke committed
278
  dataset->data_filename_ = data_filename;
279
  if (!reader->Init()) {
Guolin Ke's avatar
Guolin Ke committed
280
281
282
283
284
    Log::Fatal("Could not read binary data from %s", bin_filename);
  }

  // buffer to read binary file
  size_t buffer_size = 16 * 1024 * 1024;
Guolin Ke's avatar
Guolin Ke committed
285
  auto buffer = std::vector<char>(buffer_size);
286

287
288
  // check token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
289
290
  size_t read_cnt = reader->Read(buffer.data(), sizeof(char) * size_of_token);
  if (read_cnt != sizeof(char) * size_of_token) {
291
292
293
    Log::Fatal("Binary file error: token has the wrong size");
  }
  if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
294
    Log::Fatal("Input file is not LightGBM binary file");
295
  }
Guolin Ke's avatar
Guolin Ke committed
296
297

  // read size of header
298
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
299

300
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
301
302
303
    Log::Fatal("Binary file error: header has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
304
  size_t size_of_head = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
305
306
307
308

  // re-allocmate space if not enough
  if (size_of_head > buffer_size) {
    buffer_size = size_of_head;
Guolin Ke's avatar
Guolin Ke committed
309
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
310
311
  }
  // read header
312
  read_cnt = reader->Read(buffer.data(), size_of_head);
Guolin Ke's avatar
Guolin Ke committed
313
314
315
316
317

  if (read_cnt != size_of_head) {
    Log::Fatal("Binary file error: header is incorrect");
  }
  // get header
Guolin Ke's avatar
Guolin Ke committed
318
  const char* mem_ptr = buffer.data();
Guolin Ke's avatar
Guolin Ke committed
319
320
321
322
323
324
  dataset->num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_data_);
  dataset->num_features_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_features_);
  dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
325
326
  dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->label_idx_);
327
328
329
330
331
332
333
334
335
336
  dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->max_bin_);
  dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->bin_construct_sample_cnt_);
  dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->min_data_in_bin_);
  dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += sizeof(dataset->use_missing_);
  dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += sizeof(dataset->zero_as_missing_);
Guolin Ke's avatar
Guolin Ke committed
337
338
  dataset->sparse_threshold_ = *(reinterpret_cast<const double*>(mem_ptr));
  mem_ptr += sizeof(dataset->sparse_threshold_);
Guolin Ke's avatar
Guolin Ke committed
339
340
  const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
  dataset->used_feature_map_.clear();
Guolin Ke's avatar
Guolin Ke committed
341
  for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
342
343
    dataset->used_feature_map_.push_back(tmp_feature_map[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
  mem_ptr += sizeof(int) * dataset->num_total_features_;
  // num_groups
  dataset->num_groups_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_groups_);
  // real_feature_idx_
  const int* tmp_ptr_real_feature_idx_ = reinterpret_cast<const int*>(mem_ptr);
  dataset->real_feature_idx_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->real_feature_idx_.push_back(tmp_ptr_real_feature_idx_[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // feature2group
  const int* tmp_ptr_feature2group = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2group_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2group_.push_back(tmp_ptr_feature2group[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // feature2subfeature
  const int* tmp_ptr_feature2subfeature = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2subfeature_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2subfeature_.push_back(tmp_ptr_feature2subfeature[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // group_bin_boundaries
  const uint64_t* tmp_ptr_group_bin_boundaries = reinterpret_cast<const uint64_t*>(mem_ptr);
  dataset->group_bin_boundaries_.clear();
  for (int i = 0; i < dataset->num_groups_ + 1; ++i) {
    dataset->group_bin_boundaries_.push_back(tmp_ptr_group_bin_boundaries[i]);
  }
  mem_ptr += sizeof(uint64_t) * (dataset->num_groups_ + 1);

  // group_feature_start_
  const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_start_.clear();
380
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
381
382
383
384
385
386
387
388
389
390
391
392
    dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
  }
  mem_ptr += sizeof(int) * (dataset->num_groups_);

  // group_feature_cnt_
  const int* tmp_ptr_group_feature_cnt = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_cnt_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_cnt_.push_back(tmp_ptr_group_feature_cnt[i]);
  }
  mem_ptr += sizeof(int) * (dataset->num_groups_);

393
  if (!config_.monotone_constraints.empty()) {
394
    CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.monotone_constraints.size());
395
    dataset->monotone_types_.resize(dataset->num_features_);
396
    for (int i = 0; i < dataset->num_total_features_; ++i) {
397
      int inner_fidx = dataset->InnerFeatureIndex(i);
398
      if (inner_fidx >= 0) {
399
400
401
        dataset->monotone_types_[inner_fidx] = config_.monotone_constraints[i];
      }
    }
402
  } else {
403
404
405
406
407
    const int8_t* tmp_ptr_monotone_type = reinterpret_cast<const int8_t*>(mem_ptr);
    dataset->monotone_types_.clear();
    for (int i = 0; i < dataset->num_features_; ++i) {
      dataset->monotone_types_.push_back(tmp_ptr_monotone_type[i]);
    }
Guolin Ke's avatar
Guolin Ke committed
408
409
410
411
412
413
414
  }
  mem_ptr += sizeof(int8_t) * (dataset->num_features_);

  if (ArrayArgs<int8_t>::CheckAllZero(dataset->monotone_types_)) {
    dataset->monotone_types_.clear();
  }

415
  if (!config_.feature_contri.empty()) {
416
    CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.feature_contri.size());
417
    dataset->feature_penalty_.resize(dataset->num_features_);
418
    for (int i = 0; i < dataset->num_total_features_; ++i) {
419
      int inner_fidx = dataset->InnerFeatureIndex(i);
420
      if (inner_fidx >= 0) {
421
422
423
        dataset->feature_penalty_[inner_fidx] = config_.feature_contri[i];
      }
    }
424
  } else {
425
426
427
428
429
    const double* tmp_ptr_feature_penalty = reinterpret_cast<const double*>(mem_ptr);
    dataset->feature_penalty_.clear();
    for (int i = 0; i < dataset->num_features_; ++i) {
      dataset->feature_penalty_.push_back(tmp_ptr_feature_penalty[i]);
    }
Guolin Ke's avatar
Guolin Ke committed
430
431
432
433
434
435
436
  }
  mem_ptr += sizeof(double) * (dataset->num_features_);

  if (ArrayArgs<double>::CheckAll(dataset->feature_penalty_, 1)) {
    dataset->feature_penalty_.clear();
  }

Belinda Trotta's avatar
Belinda Trotta committed
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
  if (!config_.max_bin_by_feature.empty()) {
    CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.max_bin_by_feature.size());
    CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
  } else {
    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
    dataset->max_bin_by_feature_.clear();
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
    }
  }
  mem_ptr += sizeof(int32_t) * (dataset->num_total_features_);
  if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
    dataset->max_bin_by_feature_.clear();
  }

Guolin Ke's avatar
Guolin Ke committed
454
  // get feature names
Guolin Ke's avatar
Guolin Ke committed
455
  dataset->feature_names_.clear();
Guolin Ke's avatar
Guolin Ke committed
456
457
458
459
460
461
462
463
464
465
  // write feature names
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int str_len = *(reinterpret_cast<const int*>(mem_ptr));
    mem_ptr += sizeof(int);
    std::stringstream str_buf;
    for (int j = 0; j < str_len; ++j) {
      char tmp_char = *(reinterpret_cast<const char*>(mem_ptr));
      mem_ptr += sizeof(char);
      str_buf << tmp_char;
    }
Guolin Ke's avatar
Guolin Ke committed
466
    dataset->feature_names_.emplace_back(str_buf.str());
Guolin Ke's avatar
Guolin Ke committed
467
  }
468
469
470
471
472
473
474
475
476
477
478
479
480
  // get forced_bin_bounds_
  dataset->forced_bin_bounds_ = std::vector<std::vector<double>>(dataset->num_total_features_, std::vector<double>());
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int num_bounds = *(reinterpret_cast<const int*>(mem_ptr));
    mem_ptr += sizeof(int);
    dataset->forced_bin_bounds_[i] = std::vector<double>();
    const double* tmp_ptr_forced_bounds = reinterpret_cast<const double*>(mem_ptr);
    for (int j = 0; j < num_bounds; ++j) {
      double bound = tmp_ptr_forced_bounds[j];
      dataset->forced_bin_bounds_[i].push_back(bound);
    }
    mem_ptr += num_bounds * sizeof(double);
  }
Guolin Ke's avatar
Guolin Ke committed
481
482

  // read size of meta data
483
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
484

485
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
486
487
488
    Log::Fatal("Binary file error: meta data has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
489
  size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
490
491
492
493

  // re-allocate space if not enough
  if (size_of_metadata > buffer_size) {
    buffer_size = size_of_metadata;
Guolin Ke's avatar
Guolin Ke committed
494
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
495
496
  }
  //  read meta data
497
  read_cnt = reader->Read(buffer.data(), size_of_metadata);
Guolin Ke's avatar
Guolin Ke committed
498
499
500
501
502

  if (read_cnt != size_of_metadata) {
    Log::Fatal("Binary file error: meta data is incorrect");
  }
  // load meta data
Guolin Ke's avatar
Guolin Ke committed
503
  dataset->metadata_.LoadFromMemory(buffer.data());
Guolin Ke's avatar
Guolin Ke committed
504

505
506
  *num_global_data = dataset->num_data_;
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
507
  // sample local used data if need to partition
Guolin Ke's avatar
Guolin Ke committed
508
  if (num_machines > 1 && !config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
509
510
511
512
    const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
513
        if (random_.NextShort(0, num_machines) == rank) {
514
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
515
516
517
518
519
520
521
522
523
        }
      }
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = dataset->metadata_.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (qid >= num_queries) {
524
525
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
526
527
528
529
        }
        if (i >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
530
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
531
532
533
534
535
            is_query_used = true;
          }
          ++qid;
        }
        if (is_query_used) {
536
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
537
538
539
        }
      }
    }
540
    dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
Guolin Ke's avatar
Guolin Ke committed
541
  }
542
  dataset->metadata_.PartitionLabel(*used_data_indices);
Guolin Ke's avatar
Guolin Ke committed
543
  // read feature data
Guolin Ke's avatar
Guolin Ke committed
544
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
545
    // read feature size
546
547
    read_cnt = reader->Read(buffer.data(), sizeof(size_t));
    if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
548
549
      Log::Fatal("Binary file error: feature %d has the wrong size", i);
    }
Guolin Ke's avatar
Guolin Ke committed
550
    size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
551
552
553
    // re-allocate space if not enough
    if (size_of_feature > buffer_size) {
      buffer_size = size_of_feature;
Guolin Ke's avatar
Guolin Ke committed
554
      buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
555
556
    }

557
    read_cnt = reader->Read(buffer.data(), size_of_feature);
Guolin Ke's avatar
Guolin Ke committed
558
559
560
561

    if (read_cnt != size_of_feature) {
      Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
    }
Guolin Ke's avatar
Guolin Ke committed
562
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
563
564
      new FeatureGroup(buffer.data(),
                       *num_global_data,
565
                       *used_data_indices)));
Guolin Ke's avatar
Guolin Ke committed
566
  }
Guolin Ke's avatar
Guolin Ke committed
567
  dataset->feature_groups_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
568
  dataset->is_finish_load_ = true;
Guolin Ke's avatar
Guolin Ke committed
569
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
570
571
}

572

573
574
575
/*!
 * \brief Build a Dataset from column-wise sampled values.
 *
 * Finds one BinMapper per column (locally when there is a single machine, or
 * split across machines and exchanged through Network::Allgather otherwise),
 * then hands the mappers to Dataset::Construct.
 *
 * \param sample_values sample_values[c] holds the sampled values of column c
 * \param sample_indices per-column sample indices, forwarded to Dataset::Construct
 * \param num_col number of columns available on this machine
 * \param num_per_col num_per_col[c] is the number of entries in sample_values[c]
 * \param total_sample_size total number of sampled rows
 * \param num_data number of rows of the dataset to create
 * \return newly allocated Dataset; ownership is released to the caller
 */
Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
                                               int** sample_indices, int num_col, const int* num_per_col,
                                               size_t total_sample_size, data_size_t num_data) {
  int num_total_features = num_col;
  if (Network::num_machines() > 1) {
    // machines may see a different number of columns, use the global maximum
    num_total_features = Network::GlobalSyncUpByMax(num_total_features);
  }
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_total_features);
  // fill feature_names_ if not header
  if (feature_names_.empty()) {
    for (int i = 0; i < num_col; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
  if (!config_.max_bin_by_feature.empty()) {
    CHECK(static_cast<size_t>(num_col) == config_.max_bin_by_feature.size());
    CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
  }

  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);

  // min_data_in_leaf scaled down by the sampling ratio (total_sample_size / num_data)
  const data_size_t filter_cnt = static_cast<data_size_t>(
    static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
  if (Network::num_machines() == 1) {
    // if only one machine, find bin locally
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
        bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
        if (!feat_is_unconstrained) {
            Log::Fatal("The output cannot be monotone with respect to categorical features");
        }
      }
      bin_mappers[i].reset(new BinMapper());
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                                config_.max_bin, config_.min_data_in_bin, filter_cnt,
                                bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[i]);
      } else {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                                config_.max_bin_by_feature[i], config_.min_data_in_bin,
                                filter_cnt, bin_type, config_.use_missing,
                                config_.zero_as_missing, forced_bin_bounds[i]);
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    // if have multi-machines, need to find bin distributed
    // different machines will find bin for different features
    int num_machines = Network::num_machines();
    int rank = Network::rank();
    // start and len will store the process feature indices for different machines
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
    int step = (num_total_features + num_machines - 1) / num_machines;
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
      len[i] = std::min(step, num_total_features - start[i]);
      start[i + 1] = start[i] + len[i];
    }
    len[num_machines - 1] = num_total_features - start[num_machines - 1];
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < len[rank]; ++i) {
      OMP_LOOP_EX_BEGIN();
      // i is only the local slot inside bin_mappers; every per-feature table
      // (ignore/categorical sets, samples, forced bins) uses the global index
      const int real_feature_idx = start[rank] + i;
      if (ignore_features_.count(real_feature_idx) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(real_feature_idx)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
      // this machine has no samples for the column: keep the default mapper
      if (num_col <= real_feature_idx) {
        continue;
      }
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[real_feature_idx], num_per_col[real_feature_idx],
                                total_sample_size, config_.max_bin, config_.min_data_in_bin,
                                filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[real_feature_idx]);
      } else {
        bin_mappers[i]->FindBin(sample_values[real_feature_idx], num_per_col[real_feature_idx],
                                total_sample_size, config_.max_bin_by_feature[real_feature_idx],
                                config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing,
                                config_.zero_as_missing, forced_bin_bounds[real_feature_idx]);
      }
      OMP_LOOP_EX_END();
    }
    // re-raise anything captured inside the parallel loop above
    OMP_THROW_EX();
    // serialize the locally found bin mappers into one contiguous buffer
    comm_size_t self_buf_size = 0;
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
    }
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
      // free
      bin_mappers[i].reset(nullptr);
    }
    // per-machine byte counts and their offsets inside the gathered buffer
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
    }
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
    // gather global feature bin mappers
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
    // restore features bins from buffer
    for (int i = 0; i < num_total_features; ++i) {
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      bin_mappers[i].reset(new BinMapper());
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
  dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, num_per_col, num_col, total_sample_size, config_);
  dataset->set_feature_names(feature_names_);
  return dataset.release();
}
Guolin Ke's avatar
Guolin Ke committed
723
724
725
726
727
728


// ---- private functions ----

void DatasetLoader::CheckDataset(const Dataset* dataset) {
  if (dataset->num_data_ <= 0) {
Guolin Ke's avatar
Guolin Ke committed
729
    Log::Fatal("Data file %s is empty", dataset->data_filename_.c_str());
Guolin Ke's avatar
Guolin Ke committed
730
  }
731
732
  if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
    Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
733
               static_cast<int>(dataset->feature_names_.size()));
734
  }
Guolin Ke's avatar
Guolin Ke committed
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
  bool is_feature_order_by_group = true;
  int last_group = -1;
  int last_sub_feature = -1;
  // if features are ordered, not need to use hist_buf
  for (int i = 0; i < dataset->num_features_; ++i) {
    int group = dataset->feature2group_[i];
    int sub_feature = dataset->feature2subfeature_[i];
    if (group < last_group) {
      is_feature_order_by_group = false;
    } else if (group == last_group) {
      if (sub_feature <= last_sub_feature) {
        is_feature_order_by_group = false;
        break;
      }
    }
    last_group = group;
    last_sub_feature = sub_feature;
  }
  if (!is_feature_order_by_group) {
754
    Log::Fatal("Features in dataset should be ordered by group");
Guolin Ke's avatar
Guolin Ke committed
755
  }
Guolin Ke's avatar
Guolin Ke committed
756
757
758
}

std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
759
760
                                                             int rank, int num_machines, int* num_global_data,
                                                             std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
761
  TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
762
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
763
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
764
765
766
767
768
769
770
771
772
    // read all lines
    *num_global_data = text_reader.ReadAllLines();
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();

    if (query_boundaries == nullptr) {
      // if not contain query data, minimal sample unit is one record
      *num_global_data = text_reader.ReadAndFilterLines([this, rank, num_machines](data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
773
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
774
775
776
777
778
779
780
781
782
783
784
785
786
787
          return true;
        } else {
          return false;
        }
      }, used_data_indices);
    } else {
      // if contain query data, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.ReadAndFilterLines(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
788
789
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
790
791
792
793
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
794
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
795
796
797
798
799
800
801
802
803
804
805
806
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices);
    }
  }
  return std::move(text_reader.Lines());
}

std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
Guolin Ke's avatar
Guolin Ke committed
807
  int sample_cnt = config_.bin_construct_sample_cnt;
808
809
  if (static_cast<size_t>(sample_cnt) > data.size()) {
    sample_cnt = static_cast<int>(data.size());
810
  }
811
  auto sample_indices = random_.Sample(static_cast<int>(data.size()), sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
812
  std::vector<std::string> out(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
813
814
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    const size_t idx = sample_indices[i];
Guolin Ke's avatar
Guolin Ke committed
815
    out[i] = data[idx];
Guolin Ke's avatar
Guolin Ke committed
816
817
818
819
  }
  return out;
}

820
821
822
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata,
                                                               int rank, int num_machines, int* num_global_data,
                                                               std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
823
824
  const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
  TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
825
  std::vector<std::string> out_data;
Guolin Ke's avatar
Guolin Ke committed
826
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
827
    *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(&random_, sample_cnt, &out_data));
Guolin Ke's avatar
Guolin Ke committed
828
829
830
831
832
833
834
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      *num_global_data = text_reader.SampleAndFilterFromFile([this, rank, num_machines]
      (data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
835
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
836
837
838
839
          return true;
        } else {
          return false;
        }
Guolin Ke's avatar
Guolin Ke committed
840
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
841
842
843
844
845
846
847
848
849
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.SampleAndFilterFromFile(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
850
851
          Log::Fatal("Query id exceeds the range of the query file, "
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
852
853
854
855
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
856
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
857
858
859
860
861
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
Guolin Ke's avatar
Guolin Ke committed
862
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
863
864
865
866
867
    }
  }
  return out_data;
}

868
869
870
void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
                                                    const std::vector<std::string>& sample_data,
                                                    const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
871
  std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
872
  std::vector<std::vector<int>> sample_indices;
Guolin Ke's avatar
Guolin Ke committed
873
874
  std::vector<std::pair<int, double>> oneline_features;
  double label;
Guolin Ke's avatar
Guolin Ke committed
875
  for (int i = 0; i < static_cast<int>(sample_data.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
876
877
878
879
    oneline_features.clear();
    // parse features
    parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
    for (std::pair<int, double>& inner_data : oneline_features) {
880
      if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
881
882
        sample_values.resize(inner_data.first + 1);
        sample_indices.resize(inner_data.first + 1);
883
      }
Guolin Ke's avatar
Guolin Ke committed
884
      if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
885
886
        sample_values[inner_data.first].emplace_back(inner_data.second);
        sample_indices[inner_data.first].emplace_back(i);
Guolin Ke's avatar
Guolin Ke committed
887
888
889
890
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
891
  dataset->feature_groups_.clear();
892
893
894
895
896
897
  dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
  if (num_machines > 1) {
    dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_);
  }
  if (!feature_names_.empty()) {
    CHECK(dataset->num_total_features_ == static_cast<int>(feature_names_.size()));
898
  }
Guolin Ke's avatar
Guolin Ke committed
899

Belinda Trotta's avatar
Belinda Trotta committed
900
901
902
903
904
  if (!config_.max_bin_by_feature.empty()) {
    CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.max_bin_by_feature.size());
    CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
  }

905
906
  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
907
908
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path,
                                                                                    dataset->num_total_features_,
909
910
                                                                                    categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
911
912
913
914
915
916
  // check the range of label_idx, weight_idx and group_idx
  CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
  CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
  CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);

  // fill feature_names_ if not header
Guolin Ke's avatar
Guolin Ke committed
917
  if (feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
918
919
920
921
922
923
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
924
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
925
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
926
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
927
    static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
Guolin Ke's avatar
Guolin Ke committed
928
929
930
  // start find bins
  if (num_machines == 1) {
    // if only one machine, find bin locally
931
    OMP_INIT_EX();
932
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
933
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
934
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
935
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
936
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
937
938
        continue;
      }
939
940
941
942
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
Guolin Ke's avatar
Guolin Ke committed
943
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
944
945
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
946
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
947
948
                                filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
949
950
      } else {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
951
952
                                sample_data.size(), config_.max_bin_by_feature[i],
                                config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing,
953
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
954
      }
955
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
956
    }
957
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
958
959
  } else {
    // start and len will store the process feature indices for different machines
960
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
Guolin Ke's avatar
Guolin Ke committed
961
962
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
963
    int step = (dataset->num_total_features_ + num_machines - 1) / num_machines;
Guolin Ke's avatar
Guolin Ke committed
964
965
966
967
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
968
      len[i] = std::min(step, dataset->num_total_features_ - start[i]);
Guolin Ke's avatar
Guolin Ke committed
969
970
      start[i + 1] = start[i] + len[i];
    }
971
    len[num_machines - 1] = dataset->num_total_features_ - start[num_machines - 1];
972
    OMP_INIT_EX();
973
    #pragma omp parallel for schedule(guided)
974
    for (int i = 0; i < len[rank]; ++i) {
975
      OMP_LOOP_EX_BEGIN();
976
977
978
979
980
981
982
983
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
984
985
986
      if (sample_values.size() <= start[rank] + i) {
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
987
      if (config_.max_bin_by_feature.empty()) {
988
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
989
                                static_cast<int>(sample_values[start[rank] + i].size()),
990
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
991
                                filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing,
992
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
993
      } else {
994
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
995
                                static_cast<int>(sample_values[start[rank] + i].size()),
996
997
                                sample_data.size(), config_.max_bin_by_feature[i],
                                config_.min_data_in_bin, filter_cnt, bin_type,
998
                                config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
999
      }
1000
      OMP_LOOP_EX_END();
1001
    }
1002
    OMP_THROW_EX();
1003
    comm_size_t self_buf_size = 0;
Guolin Ke's avatar
Guolin Ke committed
1004
    for (int i = 0; i < len[rank]; ++i) {
1005
1006
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
Guolin Ke's avatar
Guolin Ke committed
1007
      }
1008
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
1009
    }
1010
1011
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1012
    for (int i = 0; i < len[rank]; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1013
1014
1015
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
1016
1017
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
1018
1019
      // free
      bin_mappers[i].reset(nullptr);
Guolin Ke's avatar
Guolin Ke committed
1020
    }
1021
1022
1023
1024
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
Guolin Ke's avatar
Guolin Ke committed
1025
    }
1026
1027
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
Guolin Ke's avatar
Guolin Ke committed
1028
    // gather global feature bin mappers
1029
1030
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1031
    // restore features bins from buffer
1032
    for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1033
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1034
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1035
1036
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
1037
      bin_mappers[i].reset(new BinMapper());
1038
1039
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
Guolin Ke's avatar
Guolin Ke committed
1040
1041
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1042
  sample_values.clear();
1043
  dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
1044
                     Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
Guolin Ke's avatar
Guolin Ke committed
1045
1046
1047
}

/*! \brief Extract local features from memory */
Guolin Ke's avatar
Guolin Ke committed
1048
void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
1049
1050
1051
  std::vector<std::pair<int, double>> oneline_features;
  double tmp_label = 0.0f;
  if (predict_fun_ == nullptr) {
1052
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1053
    // if doesn't need to prediction with initial model
1054
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1055
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1056
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1057
1058
1059
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1060
      parser->ParseOneLine(text_data->at(i).c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1061
      // set label
1062
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1063
      // free processed line:
Guolin Ke's avatar
Guolin Ke committed
1064
      text_data->at(i).clear();
Guolin Ke's avatar
Guolin Ke committed
1065
1066
1067
1068
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
      for (auto& inner_data : oneline_features) {
1069
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1070
1071
1072
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1073
1074
1075
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1076
1077
        } else {
          if (inner_data.first == weight_idx_) {
1078
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1079
1080
1081
1082
1083
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1084
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1085
    }
1086
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1087
  } else {
1088
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1089
    // if need to prediction with initial model
1090
    std::vector<double> init_score(dataset->num_data_ * num_class_);
1091
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1092
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1093
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1094
1095
1096
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1097
      parser->ParseOneLine(text_data->at(i).c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1098
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1099
1100
      std::vector<double> oneline_init_score(num_class_);
      predict_fun_(oneline_features, oneline_init_score.data());
1101
      for (int k = 0; k < num_class_; ++k) {
1102
        init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1103
1104
      }
      // set label
1105
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1106
1107
1108
1109
1110
1111
      // free processed line:
      text_data[i].clear();
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
      for (auto& inner_data : oneline_features) {
1112
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1113
1114
1115
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1116
1117
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
1118
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1119
1120
        } else {
          if (inner_data.first == weight_idx_) {
1121
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1122
1123
1124
1125
1126
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1127
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1128
    }
1129
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1130
    // metadata_ will manage space of init_score
1131
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1132
  }
Guolin Ke's avatar
Guolin Ke committed
1133
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1134
  // text data can be free after loaded feature values
Guolin Ke's avatar
Guolin Ke committed
1135
  text_data->clear();
Guolin Ke's avatar
Guolin Ke committed
1136
1137
1138
}

/*! \brief Extract local features from file */
1139
1140
void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser,
                                            const std::vector<data_size_t>& used_data_indices, Dataset* dataset) {
1141
  std::vector<double> init_score;
Guolin Ke's avatar
Guolin Ke committed
1142
  if (predict_fun_ != nullptr) {
1143
    init_score = std::vector<double>(dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1144
1145
1146
1147
1148
1149
  }
  std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
    [this, &init_score, &parser, &dataset]
  (data_size_t start_idx, const std::vector<std::string>& lines) {
    std::vector<std::pair<int, double>> oneline_features;
    double tmp_label = 0.0f;
1150
    OMP_INIT_EX();
1151
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1152
    for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
1153
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1154
1155
1156
1157
1158
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1159
      if (!init_score.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1160
1161
        std::vector<double> oneline_init_score(num_class_);
        predict_fun_(oneline_features, oneline_init_score.data());
1162
        for (int k = 0; k < num_class_; ++k) {
1163
          init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1164
1165
1166
        }
      }
      // set label
1167
      dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1168
1169
      // push data
      for (auto& inner_data : oneline_features) {
1170
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1171
1172
1173
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1174
1175
1176
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1177
1178
        } else {
          if (inner_data.first == weight_idx_) {
1179
            dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1180
1181
1182
1183
1184
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1185
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1186
    }
1187
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1188
  };
Guolin Ke's avatar
Guolin Ke committed
1189
  TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
1190
  if (!used_data_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1191
1192
1193
1194
1195
1196
1197
1198
    // only need part of data
    text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
  } else {
    // need full data
    text_reader.ReadAllAndProcessParallel(process_fun);
  }

  // metadata_ will manage space of init_score
Guolin Ke's avatar
Guolin Ke committed
1199
  if (!init_score.empty()) {
1200
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1201
  }
Guolin Ke's avatar
Guolin Ke committed
1202
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1203
1204
1205
}

/*! \brief Check can load from binary file */
1206
std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
1207
1208
1209
  std::string bin_filename(filename);
  bin_filename.append(".bin");

1210
  auto reader = VirtualFileReader::Make(bin_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
1211

1212
  if (!reader->Init()) {
1213
    bin_filename = std::string(filename);
1214
1215
    reader = VirtualFileReader::Make(bin_filename.c_str());
    if (!reader->Init()) {
1216
      Log::Fatal("Cannot open data file %s", bin_filename.c_str());
1217
    }
1218
  }
1219
1220
1221
1222
1223

  size_t buffer_size = 256;
  auto buffer = std::vector<char>(buffer_size);
  // read size of token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
1224
  size_t read_cnt = reader->Read(buffer.data(), size_of_token);
1225
1226
  if (read_cnt == size_of_token
      && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
1227
    return bin_filename;
Guolin Ke's avatar
Guolin Ke committed
1228
  } else {
1229
    return std::string();
Guolin Ke's avatar
Guolin Ke committed
1230
1231
1232
  }
}

1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270


/*!
 * \brief Load forced bin upper bounds from a JSON file
 *
 * The file is expected to hold a JSON array whose items carry a "feature" index and a
 * "bin_upper_bound" array. Bounds given for categorical features are ignored with a
 * warning. The bounds of every feature are returned sorted and without duplicates.
 *
 * \param forced_bins_path Path of the JSON file; an empty path yields no forced bins
 * \param num_total_features Number of features, i.e. size of the returned outer vector
 * \param categorical_features Indices of categorical features
 * \return One vector of forced upper bounds per feature (empty when none were given
 *         or when the file could not be opened)
 */
std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
                                                              const std::unordered_set<int>& categorical_features) {
  std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
  if (forced_bins_path != "") {
    std::ifstream forced_bins_stream(forced_bins_path.c_str());
    if (forced_bins_stream.fail()) {
      Log::Warning("Could not open %s. Will ignore.", forced_bins_path.c_str());
    } else {
      std::stringstream buffer;
      buffer << forced_bins_stream.rdbuf();
      std::string err;
      Json forced_bins_json = Json::parse(buffer.str(), err);
      CHECK(forced_bins_json.is_array());
      std::vector<Json> forced_bins_arr = forced_bins_json.array_items();
      for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
        int feature_num = forced_bins_arr[i]["feature"].int_value();
        // feature_num indexes forced_bins below, so reject both ends of the range
        CHECK(feature_num >= 0);
        CHECK(feature_num < num_total_features);
        if (categorical_features.count(feature_num)) {
          Log::Warning("Feature %d is categorical. Will ignore forced bins for this  feature.", feature_num);
        } else {
          std::vector<Json> bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items();
          for (size_t j = 0; j < bounds_arr.size(); ++j) {
            forced_bins[feature_num].push_back(bounds_arr[j].number_value());
          }
        }
      }
      // remove duplicates
      // std::unique only drops adjacent equal values, so sort first
      for (auto& bounds : forced_bins) {
        std::sort(bounds.begin(), bounds.end());
        bounds.erase(std::unique(bounds.begin(), bounds.end()), bounds.end());
      }
    }
  }
  return forced_bins;
}

1271
}  // namespace LightGBM