dataset_loader.cpp 53.1 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/dataset_loader.h>

Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/network.h>
8
9
10
#include <LightGBM/utils/array_args.h>
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
11

12
13
#include <fstream>

14
15
#include <LightGBM/json11.hpp>

16
17
using namespace json11;

Guolin Ke's avatar
Guolin Ke committed
18
19
namespace LightGBM {

Guolin Ke's avatar
Guolin Ke committed
20
21
/*!
 * \brief Create a loader bound to a configuration and resolve column settings.
 * \param io_config Configuration copied/bound into config_
 * \param predict_fun Prediction callback stored in predict_fun_
 * \param num_class Number of classes stored in num_class_
 * \param filename Data file whose header is inspected by SetHeader; may be nullptr
 */
DatasetLoader::DatasetLoader(const Config& io_config,
                             const PredictFunction& predict_fun,
                             int num_class,
                             const char* filename)
    : config_(io_config),
      random_(config_.data_random_seed),
      predict_fun_(predict_fun),
      num_class_(num_class) {
  // Defaults that stay in effect when the configuration names no
  // label / weight / group column; SetHeader overrides them otherwise.
  group_idx_ = NO_SPECIFIC;
  weight_idx_ = NO_SPECIFIC;
  label_idx_ = 0;
  SetHeader(filename);
}

/*! \brief Destructor; the loader needs no explicit cleanup. */
DatasetLoader::~DatasetLoader() = default;

Guolin Ke's avatar
Guolin Ke committed
31
void DatasetLoader::SetHeader(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
32
  std::unordered_map<std::string, int> name2idx;
Guolin Ke's avatar
Guolin Ke committed
33
34
  std::string name_prefix("name:");
  if (filename != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
35
    TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
36

Guolin Ke's avatar
Guolin Ke committed
37
    // get column names
Guolin Ke's avatar
Guolin Ke committed
38
    if (config_.header) {
Guolin Ke's avatar
Guolin Ke committed
39
      std::string first_line = text_reader.first_line();
40
      feature_names_ = Common::Split(first_line.c_str(), "\t,");
Guolin Ke's avatar
Guolin Ke committed
41
42
    }

Guolin Ke's avatar
Guolin Ke committed
43
    // load label idx first
Guolin Ke's avatar
Guolin Ke committed
44
45
46
    if (config_.label_column.size() > 0) {
      if (Common::StartsWith(config_.label_column, name_prefix)) {
        std::string name = config_.label_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
47
48
49
50
51
52
53
54
55
56
        label_idx_ = -1;
        for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
          if (name == feature_names_[i]) {
            label_idx_ = i;
            break;
          }
        }
        if (label_idx_ >= 0) {
          Log::Info("Using column %s as label", name.c_str());
        } else {
57
58
          Log::Fatal("Could not find label column %s in data file \n"
                     "or data file doesn't contain header", name.c_str());
Guolin Ke's avatar
Guolin Ke committed
59
        }
Guolin Ke's avatar
Guolin Ke committed
60
      } else {
Guolin Ke's avatar
Guolin Ke committed
61
        if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
62
63
64
          Log::Fatal("label_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
65
66
        }
        Log::Info("Using column number %d as label", label_idx_);
Guolin Ke's avatar
Guolin Ke committed
67
68
      }
    }
Guolin Ke's avatar
Guolin Ke committed
69

Guolin Ke's avatar
Guolin Ke committed
70
    if (!feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
71
72
73
74
      // erase label column name
      feature_names_.erase(feature_names_.begin() + label_idx_);
      for (size_t i = 0; i < feature_names_.size(); ++i) {
        name2idx[feature_names_[i]] = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
75
      }
Guolin Ke's avatar
Guolin Ke committed
76
77
78
    }

    // load ignore columns
Guolin Ke's avatar
Guolin Ke committed
79
80
81
    if (config_.ignore_column.size() > 0) {
      if (Common::StartsWith(config_.ignore_column, name_prefix)) {
        std::string names = config_.ignore_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
82
83
84
85
86
87
88
89
90
        for (auto name : Common::Split(names.c_str(), ',')) {
          if (name2idx.count(name) > 0) {
            int tmp = name2idx[name];
            ignore_features_.emplace(tmp);
          } else {
            Log::Fatal("Could not find ignore column %s in data file", name.c_str());
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
91
        for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
Guolin Ke's avatar
Guolin Ke committed
92
93
          int tmp = 0;
          if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
94
95
96
            Log::Fatal("ignore_column is not a number,\n"
                       "if you want to use a column name,\n"
                       "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
97
98
          }
          ignore_features_.emplace(tmp);
Guolin Ke's avatar
Guolin Ke committed
99
100
101
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
102
    // load weight idx
Guolin Ke's avatar
Guolin Ke committed
103
104
105
    if (config_.weight_column.size() > 0) {
      if (Common::StartsWith(config_.weight_column, name_prefix)) {
        std::string name = config_.weight_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
106
107
108
109
110
111
        if (name2idx.count(name) > 0) {
          weight_idx_ = name2idx[name];
          Log::Info("Using column %s as weight", name.c_str());
        } else {
          Log::Fatal("Could not find weight column %s in data file", name.c_str());
        }
Guolin Ke's avatar
Guolin Ke committed
112
      } else {
Guolin Ke's avatar
Guolin Ke committed
113
        if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
114
115
116
          Log::Fatal("weight_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
117
118
        }
        Log::Info("Using column number %d as weight", weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
119
      }
Guolin Ke's avatar
Guolin Ke committed
120
      ignore_features_.emplace(weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
121
    }
Guolin Ke's avatar
Guolin Ke committed
122
    // load group idx
Guolin Ke's avatar
Guolin Ke committed
123
124
125
    if (config_.group_column.size() > 0) {
      if (Common::StartsWith(config_.group_column, name_prefix)) {
        std::string name = config_.group_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
126
127
128
129
130
131
132
        if (name2idx.count(name) > 0) {
          group_idx_ = name2idx[name];
          Log::Info("Using column %s as group/query id", name.c_str());
        } else {
          Log::Fatal("Could not find group/query column %s in data file", name.c_str());
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
133
        if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
134
135
136
          Log::Fatal("group_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
137
138
139
140
        }
        Log::Info("Using column number %d as group/query id", group_idx_);
      }
      ignore_features_.emplace(group_idx_);
Guolin Ke's avatar
Guolin Ke committed
141
142
    }
  }
Guolin Ke's avatar
Guolin Ke committed
143
144
145
  if (config_.categorical_feature.size() > 0) {
    if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
      std::string names = config_.categorical_feature.substr(name_prefix.size());
146
147
148
149
150
      for (auto name : Common::Split(names.c_str(), ',')) {
        if (name2idx.count(name) > 0) {
          int tmp = name2idx[name];
          categorical_features_.emplace(tmp);
        } else {
Guolin Ke's avatar
Guolin Ke committed
151
          Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
152
153
154
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
155
      for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
156
157
        int tmp = 0;
        if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Guolin Ke's avatar
Guolin Ke committed
158
          Log::Fatal("categorical_feature is not a number,\n"
159
160
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
161
162
163
164
165
        }
        categorical_features_.emplace(tmp);
      }
    }
  }
Guolin Ke's avatar
Guolin Ke committed
166
167
}

168
/*!
 * \brief Load a training dataset from a text file, or from its binary
 *        counterpart when CheckCanLoadFromBin returns a non-empty file name.
 *
 * For text input the data is either read fully into memory (default) or, when
 * config_.two_round is set, sampled from the file first and parsed in a second
 * pass. In both cases bin mappers are constructed from the sample before the
 * features are extracted.
 *
 * \param filename Path of the data file
 * \param initscore_file Path handed to Metadata::Init together with filename
 * \param rank Rank of this machine
 * \param num_machines Total number of machines
 * \return Newly allocated dataset; the caller receives ownership of the raw pointer
 */
Dataset* DatasetLoader::LoadFromFile(const char* filename, const char* initscore_file, int rank, int num_machines) {
  // don't support query id in data file when training in parallel.
  // Column indices are zero-based, so index 0 is a valid group column and must
  // be rejected as well (NO_SPECIFIC is presumably negative - TODO confirm).
  if (num_machines > 1 && !config_.pre_partition && group_idx_ >= 0) {
    Log::Fatal("Using a query id without pre-partitioning the data file is not supported for parallel training.\n"
               "Please use an additional query file or pre-partition the data");
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto bin_filename = CheckCanLoadFromBin(filename);
  if (bin_filename.size() == 0) {
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename, initscore_file);
    if (!config_.two_round) {
      // read data to memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // sample data
      auto sample_data = SampleTextDataFromMemory(text_data);
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      // sample data from file
      auto sample_data = SampleTextDataFromFile(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      // an empty index list means that all rows of the file are used locally
      if (used_data_indices.size() > 0) {
        dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
      } else {
        dataset->num_data_ = num_global_data;
      }
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      Log::Debug("Making second pass...");
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
  }
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  // need to check training data
  CheckDataset(dataset.get());
  return dataset.release();
}



230
/*!
 * \brief Load a dataset that reuses the feature layout of an existing dataset
 *        (via Dataset::CreateValid) instead of constructing new bin mappers.
 *
 * The file is always read as a single, unpartitioned machine (rank 0 of 1).
 * Unlike LoadFromFile, the result is not passed through CheckDataset.
 *
 * \param filename Path of the data file
 * \param initscore_file Path handed to Metadata::Init together with filename
 * \param train_data Dataset the new dataset is aligned with
 * \return Newly allocated dataset; the caller receives ownership of the raw pointer
 */
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const char* initscore_file, const Dataset* train_data) {
  data_size_t total_rows = 0;
  std::vector<data_size_t> local_rows;
  std::unique_ptr<Dataset> result(new Dataset());
  const auto binary_path = CheckCanLoadFromBin(filename);
  if (binary_path.size() != 0) {
    // a usable binary file exists: load from it directly
    result.reset(LoadFromBinFile(filename, binary_path.c_str(), 0, 1, &total_rows, &local_rows));
  } else {
    std::unique_ptr<Parser> parser(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (!parser) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    result->data_filename_ = filename;
    result->label_idx_ = label_idx_;
    result->metadata_.Init(filename, initscore_file);
    if (config_.two_round) {
      TextReader<data_size_t> text_reader(filename, config_.header);
      // the number of lines of the data file is the number of rows
      result->num_data_ = static_cast<data_size_t>(text_reader.CountLine());
      total_rows = result->num_data_;
      // allocate label / weight / query storage
      result->metadata_.Init(result->num_data_, weight_idx_, group_idx_);
      // adopt the feature layout of the reference dataset
      result->CreateValid(train_data);
      // parse the features straight from the file
      ExtractFeaturesFromFile(filename, parser.get(), local_rows, result.get());
    } else {
      // keep the whole file in memory while parsing
      auto lines = LoadTextDataToMemory(filename, result->metadata_, 0, 1, &total_rows, &local_rows);
      result->num_data_ = static_cast<data_size_t>(lines.size());
      // allocate label / weight / query storage
      result->metadata_.Init(result->num_data_, weight_idx_, group_idx_);
      // adopt the feature layout of the reference dataset
      result->CreateValid(train_data);
      // parse the features from the in-memory lines
      ExtractFeaturesFromMemory(&lines, parser.get(), result.get());
      lines.clear();
    }
  }
  // validation data is not checked with CheckDataset; only the meta data is
  result->metadata_.CheckOrPartition(total_rows, local_rows);
  return result.release();
}

274
275
276
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
                                        int rank, int num_machines, int* num_global_data,
                                        std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
277
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
278
  auto reader = VirtualFileReader::Make(bin_filename);
Guolin Ke's avatar
Guolin Ke committed
279
  dataset->data_filename_ = data_filename;
280
  if (!reader->Init()) {
Guolin Ke's avatar
Guolin Ke committed
281
282
283
284
285
    Log::Fatal("Could not read binary data from %s", bin_filename);
  }

  // buffer to read binary file
  size_t buffer_size = 16 * 1024 * 1024;
Guolin Ke's avatar
Guolin Ke committed
286
  auto buffer = std::vector<char>(buffer_size);
287

288
289
  // check token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
290
291
  size_t read_cnt = reader->Read(buffer.data(), sizeof(char) * size_of_token);
  if (read_cnt != sizeof(char) * size_of_token) {
292
293
294
    Log::Fatal("Binary file error: token has the wrong size");
  }
  if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
295
    Log::Fatal("Input file is not LightGBM binary file");
296
  }
Guolin Ke's avatar
Guolin Ke committed
297
298

  // read size of header
299
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
300

301
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
302
303
304
    Log::Fatal("Binary file error: header has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
305
  size_t size_of_head = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
306
307
308
309

  // re-allocmate space if not enough
  if (size_of_head > buffer_size) {
    buffer_size = size_of_head;
Guolin Ke's avatar
Guolin Ke committed
310
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
311
312
  }
  // read header
313
  read_cnt = reader->Read(buffer.data(), size_of_head);
Guolin Ke's avatar
Guolin Ke committed
314
315
316
317
318

  if (read_cnt != size_of_head) {
    Log::Fatal("Binary file error: header is incorrect");
  }
  // get header
Guolin Ke's avatar
Guolin Ke committed
319
  const char* mem_ptr = buffer.data();
Guolin Ke's avatar
Guolin Ke committed
320
321
322
323
324
325
  dataset->num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_data_);
  dataset->num_features_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_features_);
  dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
326
327
  dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->label_idx_);
328
329
330
331
332
333
334
335
336
337
  dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->max_bin_);
  dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->bin_construct_sample_cnt_);
  dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->min_data_in_bin_);
  dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += sizeof(dataset->use_missing_);
  dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
  mem_ptr += sizeof(dataset->zero_as_missing_);
Guolin Ke's avatar
Guolin Ke committed
338
339
  dataset->sparse_threshold_ = *(reinterpret_cast<const double*>(mem_ptr));
  mem_ptr += sizeof(dataset->sparse_threshold_);
Guolin Ke's avatar
Guolin Ke committed
340
341
  const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
  dataset->used_feature_map_.clear();
Guolin Ke's avatar
Guolin Ke committed
342
  for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
343
344
    dataset->used_feature_map_.push_back(tmp_feature_map[i]);
  }
Guolin Ke's avatar
Guolin Ke committed
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
  mem_ptr += sizeof(int) * dataset->num_total_features_;
  // num_groups
  dataset->num_groups_ = *(reinterpret_cast<const int*>(mem_ptr));
  mem_ptr += sizeof(dataset->num_groups_);
  // real_feature_idx_
  const int* tmp_ptr_real_feature_idx_ = reinterpret_cast<const int*>(mem_ptr);
  dataset->real_feature_idx_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->real_feature_idx_.push_back(tmp_ptr_real_feature_idx_[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // feature2group
  const int* tmp_ptr_feature2group = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2group_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2group_.push_back(tmp_ptr_feature2group[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // feature2subfeature
  const int* tmp_ptr_feature2subfeature = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2subfeature_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2subfeature_.push_back(tmp_ptr_feature2subfeature[i]);
  }
  mem_ptr += sizeof(int) * dataset->num_features_;
  // group_bin_boundaries
  const uint64_t* tmp_ptr_group_bin_boundaries = reinterpret_cast<const uint64_t*>(mem_ptr);
  dataset->group_bin_boundaries_.clear();
  for (int i = 0; i < dataset->num_groups_ + 1; ++i) {
    dataset->group_bin_boundaries_.push_back(tmp_ptr_group_bin_boundaries[i]);
  }
  mem_ptr += sizeof(uint64_t) * (dataset->num_groups_ + 1);

  // group_feature_start_
  const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_start_.clear();
381
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
382
383
384
385
386
387
388
389
390
391
392
393
    dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
  }
  mem_ptr += sizeof(int) * (dataset->num_groups_);

  // group_feature_cnt_
  const int* tmp_ptr_group_feature_cnt = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_cnt_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_cnt_.push_back(tmp_ptr_group_feature_cnt[i]);
  }
  mem_ptr += sizeof(int) * (dataset->num_groups_);

394
  if (!config_.monotone_constraints.empty()) {
395
    CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.monotone_constraints.size());
396
    dataset->monotone_types_.resize(dataset->num_features_);
397
    for (int i = 0; i < dataset->num_total_features_; ++i) {
398
      int inner_fidx = dataset->InnerFeatureIndex(i);
399
      if (inner_fidx >= 0) {
400
401
402
        dataset->monotone_types_[inner_fidx] = config_.monotone_constraints[i];
      }
    }
403
  } else {
404
405
406
407
408
    const int8_t* tmp_ptr_monotone_type = reinterpret_cast<const int8_t*>(mem_ptr);
    dataset->monotone_types_.clear();
    for (int i = 0; i < dataset->num_features_; ++i) {
      dataset->monotone_types_.push_back(tmp_ptr_monotone_type[i]);
    }
Guolin Ke's avatar
Guolin Ke committed
409
410
411
412
413
414
415
  }
  mem_ptr += sizeof(int8_t) * (dataset->num_features_);

  if (ArrayArgs<int8_t>::CheckAllZero(dataset->monotone_types_)) {
    dataset->monotone_types_.clear();
  }

416
  if (!config_.feature_contri.empty()) {
417
    CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.feature_contri.size());
418
    dataset->feature_penalty_.resize(dataset->num_features_);
419
    for (int i = 0; i < dataset->num_total_features_; ++i) {
420
      int inner_fidx = dataset->InnerFeatureIndex(i);
421
      if (inner_fidx >= 0) {
422
423
424
        dataset->feature_penalty_[inner_fidx] = config_.feature_contri[i];
      }
    }
425
  } else {
426
427
428
429
430
    const double* tmp_ptr_feature_penalty = reinterpret_cast<const double*>(mem_ptr);
    dataset->feature_penalty_.clear();
    for (int i = 0; i < dataset->num_features_; ++i) {
      dataset->feature_penalty_.push_back(tmp_ptr_feature_penalty[i]);
    }
Guolin Ke's avatar
Guolin Ke committed
431
432
433
434
435
436
437
  }
  mem_ptr += sizeof(double) * (dataset->num_features_);

  if (ArrayArgs<double>::CheckAll(dataset->feature_penalty_, 1)) {
    dataset->feature_penalty_.clear();
  }

Belinda Trotta's avatar
Belinda Trotta committed
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
  if (!config_.max_bin_by_feature.empty()) {
    CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.max_bin_by_feature.size());
    CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
  } else {
    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
    dataset->max_bin_by_feature_.clear();
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
    }
  }
  mem_ptr += sizeof(int32_t) * (dataset->num_total_features_);
  if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
    dataset->max_bin_by_feature_.clear();
  }

Guolin Ke's avatar
Guolin Ke committed
455
  // get feature names
Guolin Ke's avatar
Guolin Ke committed
456
  dataset->feature_names_.clear();
Guolin Ke's avatar
Guolin Ke committed
457
458
459
460
461
462
463
464
465
466
  // write feature names
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int str_len = *(reinterpret_cast<const int*>(mem_ptr));
    mem_ptr += sizeof(int);
    std::stringstream str_buf;
    for (int j = 0; j < str_len; ++j) {
      char tmp_char = *(reinterpret_cast<const char*>(mem_ptr));
      mem_ptr += sizeof(char);
      str_buf << tmp_char;
    }
Guolin Ke's avatar
Guolin Ke committed
467
    dataset->feature_names_.emplace_back(str_buf.str());
Guolin Ke's avatar
Guolin Ke committed
468
  }
469
470
471
472
473
474
475
476
477
478
479
480
481
  // get forced_bin_bounds_
  dataset->forced_bin_bounds_ = std::vector<std::vector<double>>(dataset->num_total_features_, std::vector<double>());
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int num_bounds = *(reinterpret_cast<const int*>(mem_ptr));
    mem_ptr += sizeof(int);
    dataset->forced_bin_bounds_[i] = std::vector<double>();
    const double* tmp_ptr_forced_bounds = reinterpret_cast<const double*>(mem_ptr);
    for (int j = 0; j < num_bounds; ++j) {
      double bound = tmp_ptr_forced_bounds[j];
      dataset->forced_bin_bounds_[i].push_back(bound);
    }
    mem_ptr += num_bounds * sizeof(double);
  }
Guolin Ke's avatar
Guolin Ke committed
482
483

  // read size of meta data
484
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
485

486
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
487
488
489
    Log::Fatal("Binary file error: meta data has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
490
  size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
491
492
493
494

  // re-allocate space if not enough
  if (size_of_metadata > buffer_size) {
    buffer_size = size_of_metadata;
Guolin Ke's avatar
Guolin Ke committed
495
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
496
497
  }
  //  read meta data
498
  read_cnt = reader->Read(buffer.data(), size_of_metadata);
Guolin Ke's avatar
Guolin Ke committed
499
500
501
502
503

  if (read_cnt != size_of_metadata) {
    Log::Fatal("Binary file error: meta data is incorrect");
  }
  // load meta data
Guolin Ke's avatar
Guolin Ke committed
504
  dataset->metadata_.LoadFromMemory(buffer.data());
Guolin Ke's avatar
Guolin Ke committed
505

506
507
  *num_global_data = dataset->num_data_;
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
508
  // sample local used data if need to partition
Guolin Ke's avatar
Guolin Ke committed
509
  if (num_machines > 1 && !config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
510
511
512
513
    const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
514
        if (random_.NextShort(0, num_machines) == rank) {
515
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
516
517
518
519
520
521
522
523
524
        }
      }
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = dataset->metadata_.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (qid >= num_queries) {
525
526
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
527
528
529
530
        }
        if (i >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
531
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
532
533
534
535
536
            is_query_used = true;
          }
          ++qid;
        }
        if (is_query_used) {
537
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
538
539
540
        }
      }
    }
541
    dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
Guolin Ke's avatar
Guolin Ke committed
542
  }
543
  dataset->metadata_.PartitionLabel(*used_data_indices);
Guolin Ke's avatar
Guolin Ke committed
544
  // read feature data
Guolin Ke's avatar
Guolin Ke committed
545
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
546
    // read feature size
547
548
    read_cnt = reader->Read(buffer.data(), sizeof(size_t));
    if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
549
550
      Log::Fatal("Binary file error: feature %d has the wrong size", i);
    }
Guolin Ke's avatar
Guolin Ke committed
551
    size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
552
553
554
    // re-allocate space if not enough
    if (size_of_feature > buffer_size) {
      buffer_size = size_of_feature;
Guolin Ke's avatar
Guolin Ke committed
555
      buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
556
557
    }

558
    read_cnt = reader->Read(buffer.data(), size_of_feature);
Guolin Ke's avatar
Guolin Ke committed
559
560
561
562

    if (read_cnt != size_of_feature) {
      Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
    }
Guolin Ke's avatar
Guolin Ke committed
563
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
564
565
      new FeatureGroup(buffer.data(),
                       *num_global_data,
566
                       *used_data_indices)));
Guolin Ke's avatar
Guolin Ke committed
567
  }
Guolin Ke's avatar
Guolin Ke committed
568
  dataset->feature_groups_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
569
  dataset->is_finish_load_ = true;
Guolin Ke's avatar
Guolin Ke committed
570
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
571
572
}

573

574
575
576
/*!
 * \brief Build a Dataset from already-sampled column values.
 *
 * Finds one BinMapper per feature, either locally (single machine) or by letting every
 * machine bin a contiguous slice of the features and exchanging the serialized mappers
 * with Network::Allgather, and then hands the mappers to Dataset::Construct.
 *
 * \param sample_values sample_values[j] holds the sampled values of column j
 * \param sample_indices forwarded unchanged to Dataset::Construct
 * \param num_col number of local columns
 * \param num_per_col num_per_col[j] is the number of entries in sample_values[j]
 * \param total_sample_size number of sampled rows
 * \param num_data number of rows of the dataset to create
 * \return newly allocated Dataset; ownership is released to the caller
 */
Dataset* DatasetLoader::CostructFromSampleData(double** sample_values,
                                               int** sample_indices, int num_col, const int* num_per_col,
                                               size_t total_sample_size, data_size_t num_data) {
  int num_total_features = num_col;
  if (Network::num_machines() > 1) {
    num_total_features = Network::GlobalSyncUpByMax(num_total_features);
  }
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_total_features);
  // fill feature_names_ if not header
  if (feature_names_.empty()) {
    for (int i = 0; i < num_col; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
  if (!config_.max_bin_by_feature.empty()) {
    CHECK(static_cast<size_t>(num_col) == config_.max_bin_by_feature.size());
    CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
  }

  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);

  const data_size_t filter_cnt = static_cast<data_size_t>(
    static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
  if (Network::num_machines() == 1) {
    // if only one machine, find bin locally
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
        bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
        if (!feat_is_unconstrained) {
          Log::Fatal("The output cannot be monotone with respect to categorical features");
        }
      }
      bin_mappers[i].reset(new BinMapper());
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                                config_.max_bin, config_.min_data_in_bin, filter_cnt,
                                bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[i]);
      } else {
        bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                                config_.max_bin_by_feature[i], config_.min_data_in_bin,
                                filter_cnt, bin_type, config_.use_missing,
                                config_.zero_as_missing, forced_bin_bounds[i]);
      }
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    // if have multi-machines, need to find bin distributed
    // different machines will find bin for different features
    int num_machines = Network::num_machines();
    int rank = Network::rank();
    // start and len will store the process feature indices for different machines
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
    int step = (num_total_features + num_machines - 1) / num_machines;
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
      len[i] = std::min(step, num_total_features - start[i]);
      start[i + 1] = start[i] + len[i];
    }
    len[num_machines - 1] = num_total_features - start[num_machines - 1];
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < len[rank]; ++i) {
      OMP_LOOP_EX_BEGIN();
      // i is the local slot in bin_mappers; feature_idx is the global column it stands for
      const int feature_idx = start[rank] + i;
      if (ignore_features_.count(feature_idx) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(feature_idx)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
      // this machine has no sample for the column: keep the default mapper so that
      // the serialization below still has something to write for this slot
      if (num_col <= feature_idx) {
        continue;
      }
      // per-feature settings (forced bounds, max_bin_by_feature) are keyed by the
      // global feature index, not by the local slot
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[feature_idx], num_per_col[feature_idx],
                                total_sample_size, config_.max_bin, config_.min_data_in_bin,
                                filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[feature_idx]);
      } else {
        bin_mappers[i]->FindBin(sample_values[feature_idx], num_per_col[feature_idx],
                                total_sample_size, config_.max_bin_by_feature[feature_idx],
                                config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing,
                                config_.zero_as_missing, forced_bin_bounds[feature_idx]);
      }
      OMP_LOOP_EX_END();
    }
    // re-raise anything captured inside the parallel loop before the mappers are used
    OMP_THROW_EX();
    comm_size_t self_buf_size = 0;
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
    }
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
      // free
      bin_mappers[i].reset(nullptr);
    }
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
    }
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
    // gather global feature bin mappers
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
    // restore features bins from buffer
    for (int i = 0; i < num_total_features; ++i) {
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      bin_mappers[i].reset(new BinMapper());
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
  dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, num_per_col, num_col, total_sample_size, config_);
  dataset->set_feature_names(feature_names_);
  return dataset.release();
}
Guolin Ke's avatar
Guolin Ke committed
724
725
726
727
728
729


// ---- private functions ----

void DatasetLoader::CheckDataset(const Dataset* dataset) {
  if (dataset->num_data_ <= 0) {
Guolin Ke's avatar
Guolin Ke committed
730
    Log::Fatal("Data file %s is empty", dataset->data_filename_.c_str());
Guolin Ke's avatar
Guolin Ke committed
731
  }
732
733
  if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
    Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
734
               static_cast<int>(dataset->feature_names_.size()));
735
  }
Guolin Ke's avatar
Guolin Ke committed
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
  bool is_feature_order_by_group = true;
  int last_group = -1;
  int last_sub_feature = -1;
  // if features are ordered, not need to use hist_buf
  for (int i = 0; i < dataset->num_features_; ++i) {
    int group = dataset->feature2group_[i];
    int sub_feature = dataset->feature2subfeature_[i];
    if (group < last_group) {
      is_feature_order_by_group = false;
    } else if (group == last_group) {
      if (sub_feature <= last_sub_feature) {
        is_feature_order_by_group = false;
        break;
      }
    }
    last_group = group;
    last_sub_feature = sub_feature;
  }
  if (!is_feature_order_by_group) {
755
    Log::Fatal("Features in dataset should be ordered by group");
Guolin Ke's avatar
Guolin Ke committed
756
  }
Guolin Ke's avatar
Guolin Ke committed
757
758
759
}

std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
760
761
                                                             int rank, int num_machines, int* num_global_data,
                                                             std::vector<data_size_t>* used_data_indices) {
762
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
763
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
764
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
765
766
767
768
769
770
771
772
773
    // read all lines
    *num_global_data = text_reader.ReadAllLines();
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();

    if (query_boundaries == nullptr) {
      // if not contain query data, minimal sample unit is one record
      *num_global_data = text_reader.ReadAndFilterLines([this, rank, num_machines](data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
774
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
775
776
777
778
779
780
781
782
783
784
785
786
787
788
          return true;
        } else {
          return false;
        }
      }, used_data_indices);
    } else {
      // if contain query data, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.ReadAndFilterLines(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
789
790
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
791
792
793
794
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
795
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
796
797
798
799
800
801
802
803
804
805
806
807
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices);
    }
  }
  return std::move(text_reader.Lines());
}

std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
Guolin Ke's avatar
Guolin Ke committed
808
  int sample_cnt = config_.bin_construct_sample_cnt;
809
810
  if (static_cast<size_t>(sample_cnt) > data.size()) {
    sample_cnt = static_cast<int>(data.size());
811
  }
812
  auto sample_indices = random_.Sample(static_cast<int>(data.size()), sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
813
  std::vector<std::string> out(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
814
815
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    const size_t idx = sample_indices[i];
Guolin Ke's avatar
Guolin Ke committed
816
    out[i] = data[idx];
Guolin Ke's avatar
Guolin Ke committed
817
818
819
820
  }
  return out;
}

821
822
823
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata,
                                                               int rank, int num_machines, int* num_global_data,
                                                               std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
824
  const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
825
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
826
  std::vector<std::string> out_data;
Guolin Ke's avatar
Guolin Ke committed
827
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
828
    *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(&random_, sample_cnt, &out_data));
Guolin Ke's avatar
Guolin Ke committed
829
830
831
832
833
834
835
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      *num_global_data = text_reader.SampleAndFilterFromFile([this, rank, num_machines]
      (data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
836
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
837
838
839
840
          return true;
        } else {
          return false;
        }
Guolin Ke's avatar
Guolin Ke committed
841
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
842
843
844
845
846
847
848
849
850
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.SampleAndFilterFromFile(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
851
852
          Log::Fatal("Query id exceeds the range of the query file, "
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
853
854
855
856
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
857
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
858
859
860
861
862
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
Guolin Ke's avatar
Guolin Ke committed
863
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
864
865
866
867
868
    }
  }
  return out_data;
}

869
870
871
void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
                                                    const std::vector<std::string>& sample_data,
                                                    const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
872
  std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
873
  std::vector<std::vector<int>> sample_indices;
Guolin Ke's avatar
Guolin Ke committed
874
875
  std::vector<std::pair<int, double>> oneline_features;
  double label;
Guolin Ke's avatar
Guolin Ke committed
876
  for (int i = 0; i < static_cast<int>(sample_data.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
877
878
879
880
    oneline_features.clear();
    // parse features
    parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
    for (std::pair<int, double>& inner_data : oneline_features) {
881
      if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
882
883
        sample_values.resize(inner_data.first + 1);
        sample_indices.resize(inner_data.first + 1);
884
      }
Guolin Ke's avatar
Guolin Ke committed
885
      if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
886
887
        sample_values[inner_data.first].emplace_back(inner_data.second);
        sample_indices[inner_data.first].emplace_back(i);
Guolin Ke's avatar
Guolin Ke committed
888
889
890
891
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
892
  dataset->feature_groups_.clear();
893
894
895
896
897
898
  dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
  if (num_machines > 1) {
    dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_);
  }
  if (!feature_names_.empty()) {
    CHECK(dataset->num_total_features_ == static_cast<int>(feature_names_.size()));
899
  }
Guolin Ke's avatar
Guolin Ke committed
900

Belinda Trotta's avatar
Belinda Trotta committed
901
902
903
904
905
  if (!config_.max_bin_by_feature.empty()) {
    CHECK(static_cast<size_t>(dataset->num_total_features_) == config_.max_bin_by_feature.size());
    CHECK(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())) > 1);
  }

906
907
  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
908
909
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path,
                                                                                    dataset->num_total_features_,
910
911
                                                                                    categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
912
913
914
915
916
917
  // check the range of label_idx, weight_idx and group_idx
  CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
  CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
  CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);

  // fill feature_names_ if not header
Guolin Ke's avatar
Guolin Ke committed
918
  if (feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
919
920
921
922
923
924
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
925
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
926
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
927
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
928
    static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
Guolin Ke's avatar
Guolin Ke committed
929
930
931
  // start find bins
  if (num_machines == 1) {
    // if only one machine, find bin locally
932
    OMP_INIT_EX();
933
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
934
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
935
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
936
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
937
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
938
939
        continue;
      }
940
941
942
943
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
Guolin Ke's avatar
Guolin Ke committed
944
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
945
946
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
947
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
948
949
                                filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing,
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
950
951
      } else {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
952
953
                                sample_data.size(), config_.max_bin_by_feature[i],
                                config_.min_data_in_bin, filter_cnt, bin_type, config_.use_missing,
954
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
955
      }
956
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
957
    }
958
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
959
960
  } else {
    // start and len will store the process feature indices for different machines
961
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
Guolin Ke's avatar
Guolin Ke committed
962
963
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
964
    int step = (dataset->num_total_features_ + num_machines - 1) / num_machines;
Guolin Ke's avatar
Guolin Ke committed
965
966
967
968
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
969
      len[i] = std::min(step, dataset->num_total_features_ - start[i]);
Guolin Ke's avatar
Guolin Ke committed
970
971
      start[i + 1] = start[i] + len[i];
    }
972
    len[num_machines - 1] = dataset->num_total_features_ - start[num_machines - 1];
973
    OMP_INIT_EX();
974
    #pragma omp parallel for schedule(guided)
975
    for (int i = 0; i < len[rank]; ++i) {
976
      OMP_LOOP_EX_BEGIN();
977
978
979
980
981
982
983
984
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
Nikita Titov's avatar
Nikita Titov committed
985
      if (static_cast<int>(sample_values.size()) <= start[rank] + i) {
986
987
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
988
      if (config_.max_bin_by_feature.empty()) {
989
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
990
                                static_cast<int>(sample_values[start[rank] + i].size()),
991
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
992
                                filter_cnt, bin_type, config_.use_missing, config_.zero_as_missing,
993
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
994
      } else {
995
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
996
                                static_cast<int>(sample_values[start[rank] + i].size()),
997
998
                                sample_data.size(), config_.max_bin_by_feature[i],
                                config_.min_data_in_bin, filter_cnt, bin_type,
999
                                config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
1000
      }
1001
      OMP_LOOP_EX_END();
1002
    }
1003
    OMP_THROW_EX();
1004
    comm_size_t self_buf_size = 0;
Guolin Ke's avatar
Guolin Ke committed
1005
    for (int i = 0; i < len[rank]; ++i) {
1006
1007
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
Guolin Ke's avatar
Guolin Ke committed
1008
      }
1009
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
1010
    }
1011
1012
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1013
    for (int i = 0; i < len[rank]; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1014
1015
1016
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
1017
1018
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
1019
1020
      // free
      bin_mappers[i].reset(nullptr);
Guolin Ke's avatar
Guolin Ke committed
1021
    }
1022
1023
1024
1025
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
Guolin Ke's avatar
Guolin Ke committed
1026
    }
1027
1028
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
Guolin Ke's avatar
Guolin Ke committed
1029
    // gather global feature bin mappers
1030
1031
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
1032
    // restore features bins from buffer
1033
    for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1034
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1035
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1036
1037
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
1038
      bin_mappers[i].reset(new BinMapper());
1039
1040
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
Guolin Ke's avatar
Guolin Ke committed
1041
1042
    }
  }
Guolin Ke's avatar
Guolin Ke committed
1043
  sample_values.clear();
1044
  dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
1045
                     Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
Guolin Ke's avatar
Guolin Ke committed
1046
1047
1048
}

/*! \brief Extract local features from memory */
Guolin Ke's avatar
Guolin Ke committed
1049
void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
1050
1051
  std::vector<std::pair<int, double>> oneline_features;
  double tmp_label = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
1052
  auto& ref_text_data = *text_data;
Guolin Ke's avatar
Guolin Ke committed
1053
  if (predict_fun_ == nullptr) {
1054
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1055
    // if doesn't need to prediction with initial model
1056
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1057
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1058
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1059
1060
1061
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1062
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1063
      // set label
1064
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1065
      // free processed line:
Guolin Ke's avatar
Guolin Ke committed
1066
      ref_text_data[i].clear();
Guolin Ke's avatar
Guolin Ke committed
1067
1068
1069
1070
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
      for (auto& inner_data : oneline_features) {
1071
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1072
1073
1074
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1075
1076
1077
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1078
1079
        } else {
          if (inner_data.first == weight_idx_) {
1080
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1081
1082
1083
1084
1085
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1086
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1087
    }
1088
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1089
  } else {
1090
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1091
    // if need to prediction with initial model
1092
    std::vector<double> init_score(dataset->num_data_ * num_class_);
1093
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1094
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1095
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1096
1097
1098
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1099
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1100
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1101
1102
      std::vector<double> oneline_init_score(num_class_);
      predict_fun_(oneline_features, oneline_init_score.data());
1103
      for (int k = 0; k < num_class_; ++k) {
1104
        init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1105
1106
      }
      // set label
1107
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1108
1109
1110
1111
1112
1113
      // free processed line:
      text_data[i].clear();
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
      for (auto& inner_data : oneline_features) {
1114
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1115
1116
1117
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1118
1119
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
1120
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1121
1122
        } else {
          if (inner_data.first == weight_idx_) {
1123
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1124
1125
1126
1127
1128
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1129
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1130
    }
1131
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1132
    // metadata_ will manage space of init_score
1133
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1134
  }
Guolin Ke's avatar
Guolin Ke committed
1135
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1136
  // text data can be free after loaded feature values
Guolin Ke's avatar
Guolin Ke committed
1137
  text_data->clear();
Guolin Ke's avatar
Guolin Ke committed
1138
1139
1140
}

/*! \brief Extract local features from file */
1141
1142
void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser,
                                            const std::vector<data_size_t>& used_data_indices, Dataset* dataset) {
1143
  std::vector<double> init_score;
Guolin Ke's avatar
Guolin Ke committed
1144
  if (predict_fun_ != nullptr) {
1145
    init_score = std::vector<double>(dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1146
1147
1148
1149
1150
1151
  }
  std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
    [this, &init_score, &parser, &dataset]
  (data_size_t start_idx, const std::vector<std::string>& lines) {
    std::vector<std::pair<int, double>> oneline_features;
    double tmp_label = 0.0f;
1152
    OMP_INIT_EX();
1153
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1154
    for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
1155
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1156
1157
1158
1159
1160
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1161
      if (!init_score.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1162
1163
        std::vector<double> oneline_init_score(num_class_);
        predict_fun_(oneline_features, oneline_init_score.data());
1164
        for (int k = 0; k < num_class_; ++k) {
1165
          init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1166
1167
1168
        }
      }
      // set label
1169
      dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1170
1171
      // push data
      for (auto& inner_data : oneline_features) {
1172
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1173
1174
1175
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1176
1177
1178
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1179
1180
        } else {
          if (inner_data.first == weight_idx_) {
1181
            dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1182
1183
1184
1185
1186
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
1187
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1188
    }
1189
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1190
  };
1191
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
1192
  if (!used_data_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1193
1194
1195
1196
1197
1198
1199
1200
    // only need part of data
    text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
  } else {
    // need full data
    text_reader.ReadAllAndProcessParallel(process_fun);
  }

  // metadata_ will manage space of init_score
Guolin Ke's avatar
Guolin Ke committed
1201
  if (!init_score.empty()) {
1202
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1203
  }
Guolin Ke's avatar
Guolin Ke committed
1204
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1205
1206
1207
}

/*! \brief Check can load from binary file */
1208
std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
1209
1210
1211
  std::string bin_filename(filename);
  bin_filename.append(".bin");

1212
  auto reader = VirtualFileReader::Make(bin_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
1213

1214
  if (!reader->Init()) {
1215
    bin_filename = std::string(filename);
1216
1217
    reader = VirtualFileReader::Make(bin_filename.c_str());
    if (!reader->Init()) {
1218
      Log::Fatal("Cannot open data file %s", bin_filename.c_str());
1219
    }
1220
  }
1221
1222
1223
1224
1225

  size_t buffer_size = 256;
  auto buffer = std::vector<char>(buffer_size);
  // read size of token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
1226
  size_t read_cnt = reader->Read(buffer.data(), size_of_token);
1227
1228
  if (read_cnt == size_of_token
      && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
1229
    return bin_filename;
Guolin Ke's avatar
Guolin Ke committed
1230
  } else {
1231
    return std::string();
Guolin Ke's avatar
Guolin Ke committed
1232
1233
1234
  }
}

1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272


std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
                                                              const std::unordered_set<int>& categorical_features) {
  std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
  if (forced_bins_path != "") {
    std::ifstream forced_bins_stream(forced_bins_path.c_str());
    if (forced_bins_stream.fail()) {
      Log::Warning("Could not open %s. Will ignore.", forced_bins_path.c_str());
    } else {
      std::stringstream buffer;
      buffer << forced_bins_stream.rdbuf();
      std::string err;
      Json forced_bins_json = Json::parse(buffer.str(), err);
      CHECK(forced_bins_json.is_array());
      std::vector<Json> forced_bins_arr = forced_bins_json.array_items();
      for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
        int feature_num = forced_bins_arr[i]["feature"].int_value();
        CHECK(feature_num < num_total_features);
        if (categorical_features.count(feature_num)) {
          Log::Warning("Feature %d is categorical. Will ignore forced bins for this  feature.", feature_num);
        } else {
          std::vector<Json> bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items();
          for (size_t j = 0; j < bounds_arr.size(); ++j) {
            forced_bins[feature_num].push_back(bounds_arr[j].number_value());
          }
        }
      }
      // remove duplicates
      for (int i = 0; i < num_total_features; ++i) {
        auto new_end = std::unique(forced_bins[i].begin(), forced_bins[i].end());
        forced_bins[i].erase(new_end, forced_bins[i].end());
      }
    }
  }
  return forced_bins;
}

1273
}  // namespace LightGBM