"vscode:/vscode.git/clone" did not exist on "41c366ee6b69571202ffbdca5965219cb6cf0ac4"
dataset_loader.cpp 52.8 KB
Newer Older
1
2
3
4
/*!
 * Copyright (c) 2016 Microsoft Corporation. All rights reserved.
 * Licensed under the MIT License. See LICENSE file in the project root for license information.
 */
Guolin Ke's avatar
Guolin Ke committed
5
6
#include <LightGBM/dataset_loader.h>

Guolin Ke's avatar
Guolin Ke committed
7
#include <LightGBM/network.h>
8
#include <LightGBM/utils/array_args.h>
9
#include <LightGBM/utils/json11.h>
10
11
#include <LightGBM/utils/log.h>
#include <LightGBM/utils/openmp_wrapper.h>
Guolin Ke's avatar
Guolin Ke committed
12

13
14
#include <fstream>

Guolin Ke's avatar
Guolin Ke committed
15
16
namespace LightGBM {

17
18
using json11::Json;

Guolin Ke's avatar
Guolin Ke committed
19
20
/*!
 * \brief Build a loader bound to a configuration and an optional data file.
 * \param io_config Configuration copied into config_
 * \param predict_fun Prediction callback stored in predict_fun_
 * \param num_class Number of classes stored in num_class_
 * \param filename Data file passed to SetHeader(); may be nullptr
 */
DatasetLoader::DatasetLoader(const Config& io_config, const PredictFunction& predict_fun, int num_class, const char* filename)
    : config_(io_config),
      random_(config_.data_random_seed),
      predict_fun_(predict_fun),
      num_class_(num_class) {
  // start with no weight / group column and the label in column 0;
  // SetHeader() overrides these from the configuration
  weight_idx_ = NO_SPECIFIC;
  group_idx_ = NO_SPECIFIC;
  label_idx_ = 0;
  SetHeader(filename);
}

/*! \brief Destructor; the loader has no cleanup of its own to perform. */
DatasetLoader::~DatasetLoader() = default;

Guolin Ke's avatar
Guolin Ke committed
30
void DatasetLoader::SetHeader(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
31
  std::unordered_map<std::string, int> name2idx;
Guolin Ke's avatar
Guolin Ke committed
32
33
  std::string name_prefix("name:");
  if (filename != nullptr) {
Guolin Ke's avatar
Guolin Ke committed
34
    TextReader<data_size_t> text_reader(filename, config_.header);
Guolin Ke's avatar
Guolin Ke committed
35

Guolin Ke's avatar
Guolin Ke committed
36
    // get column names
Guolin Ke's avatar
Guolin Ke committed
37
    if (config_.header) {
Guolin Ke's avatar
Guolin Ke committed
38
      std::string first_line = text_reader.first_line();
39
      feature_names_ = Common::Split(first_line.c_str(), "\t,");
Guolin Ke's avatar
Guolin Ke committed
40
41
    }

Guolin Ke's avatar
Guolin Ke committed
42
    // load label idx first
Guolin Ke's avatar
Guolin Ke committed
43
44
45
    if (config_.label_column.size() > 0) {
      if (Common::StartsWith(config_.label_column, name_prefix)) {
        std::string name = config_.label_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
46
47
48
49
50
51
52
53
54
55
        label_idx_ = -1;
        for (int i = 0; i < static_cast<int>(feature_names_.size()); ++i) {
          if (name == feature_names_[i]) {
            label_idx_ = i;
            break;
          }
        }
        if (label_idx_ >= 0) {
          Log::Info("Using column %s as label", name.c_str());
        } else {
56
57
          Log::Fatal("Could not find label column %s in data file \n"
                     "or data file doesn't contain header", name.c_str());
Guolin Ke's avatar
Guolin Ke committed
58
        }
Guolin Ke's avatar
Guolin Ke committed
59
      } else {
Guolin Ke's avatar
Guolin Ke committed
60
        if (!Common::AtoiAndCheck(config_.label_column.c_str(), &label_idx_)) {
61
62
63
          Log::Fatal("label_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
64
65
        }
        Log::Info("Using column number %d as label", label_idx_);
Guolin Ke's avatar
Guolin Ke committed
66
67
      }
    }
Guolin Ke's avatar
Guolin Ke committed
68

Guolin Ke's avatar
Guolin Ke committed
69
    if (!feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
70
71
72
73
      // erase label column name
      feature_names_.erase(feature_names_.begin() + label_idx_);
      for (size_t i = 0; i < feature_names_.size(); ++i) {
        name2idx[feature_names_[i]] = static_cast<int>(i);
Guolin Ke's avatar
Guolin Ke committed
74
      }
Guolin Ke's avatar
Guolin Ke committed
75
76
77
    }

    // load ignore columns
Guolin Ke's avatar
Guolin Ke committed
78
79
80
    if (config_.ignore_column.size() > 0) {
      if (Common::StartsWith(config_.ignore_column, name_prefix)) {
        std::string names = config_.ignore_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
81
82
83
84
85
86
87
88
89
        for (auto name : Common::Split(names.c_str(), ',')) {
          if (name2idx.count(name) > 0) {
            int tmp = name2idx[name];
            ignore_features_.emplace(tmp);
          } else {
            Log::Fatal("Could not find ignore column %s in data file", name.c_str());
          }
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
90
        for (auto token : Common::Split(config_.ignore_column.c_str(), ',')) {
Guolin Ke's avatar
Guolin Ke committed
91
92
          int tmp = 0;
          if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
93
94
95
            Log::Fatal("ignore_column is not a number,\n"
                       "if you want to use a column name,\n"
                       "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
96
97
          }
          ignore_features_.emplace(tmp);
Guolin Ke's avatar
Guolin Ke committed
98
99
100
        }
      }
    }
Guolin Ke's avatar
Guolin Ke committed
101
    // load weight idx
Guolin Ke's avatar
Guolin Ke committed
102
103
104
    if (config_.weight_column.size() > 0) {
      if (Common::StartsWith(config_.weight_column, name_prefix)) {
        std::string name = config_.weight_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
105
106
107
108
109
110
        if (name2idx.count(name) > 0) {
          weight_idx_ = name2idx[name];
          Log::Info("Using column %s as weight", name.c_str());
        } else {
          Log::Fatal("Could not find weight column %s in data file", name.c_str());
        }
Guolin Ke's avatar
Guolin Ke committed
111
      } else {
Guolin Ke's avatar
Guolin Ke committed
112
        if (!Common::AtoiAndCheck(config_.weight_column.c_str(), &weight_idx_)) {
113
114
115
          Log::Fatal("weight_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
116
117
        }
        Log::Info("Using column number %d as weight", weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
118
      }
Guolin Ke's avatar
Guolin Ke committed
119
      ignore_features_.emplace(weight_idx_);
Guolin Ke's avatar
Guolin Ke committed
120
    }
Guolin Ke's avatar
Guolin Ke committed
121
    // load group idx
Guolin Ke's avatar
Guolin Ke committed
122
123
124
    if (config_.group_column.size() > 0) {
      if (Common::StartsWith(config_.group_column, name_prefix)) {
        std::string name = config_.group_column.substr(name_prefix.size());
Guolin Ke's avatar
Guolin Ke committed
125
126
127
128
129
130
131
        if (name2idx.count(name) > 0) {
          group_idx_ = name2idx[name];
          Log::Info("Using column %s as group/query id", name.c_str());
        } else {
          Log::Fatal("Could not find group/query column %s in data file", name.c_str());
        }
      } else {
Guolin Ke's avatar
Guolin Ke committed
132
        if (!Common::AtoiAndCheck(config_.group_column.c_str(), &group_idx_)) {
133
134
135
          Log::Fatal("group_column is not a number,\n"
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
Guolin Ke's avatar
Guolin Ke committed
136
137
138
139
        }
        Log::Info("Using column number %d as group/query id", group_idx_);
      }
      ignore_features_.emplace(group_idx_);
Guolin Ke's avatar
Guolin Ke committed
140
141
    }
  }
Guolin Ke's avatar
Guolin Ke committed
142
143
144
  if (config_.categorical_feature.size() > 0) {
    if (Common::StartsWith(config_.categorical_feature, name_prefix)) {
      std::string names = config_.categorical_feature.substr(name_prefix.size());
145
146
147
148
149
      for (auto name : Common::Split(names.c_str(), ',')) {
        if (name2idx.count(name) > 0) {
          int tmp = name2idx[name];
          categorical_features_.emplace(tmp);
        } else {
Guolin Ke's avatar
Guolin Ke committed
150
          Log::Fatal("Could not find categorical_feature %s in data file", name.c_str());
151
152
153
        }
      }
    } else {
Guolin Ke's avatar
Guolin Ke committed
154
      for (auto token : Common::Split(config_.categorical_feature.c_str(), ',')) {
155
156
        int tmp = 0;
        if (!Common::AtoiAndCheck(token.c_str(), &tmp)) {
Guolin Ke's avatar
Guolin Ke committed
157
          Log::Fatal("categorical_feature is not a number,\n"
158
159
                     "if you want to use a column name,\n"
                     "please add the prefix \"name:\" to the column name");
160
161
162
163
164
        }
        categorical_features_.emplace(tmp);
      }
    }
  }
Guolin Ke's avatar
Guolin Ke committed
165
166
}

167
/*!
 * \brief Load a training dataset from a text file, or from its binary form when
 *        CheckCanLoadFromBin() returns a non-empty file name.
 *
 * For text input the data is either read fully into memory (config_.two_round == false)
 * or processed in two passes over the file (config_.two_round == true). In both cases
 * bin mappers are built from a sample before the features are extracted.
 *
 * \param filename Path of the data file
 * \param rank Rank of this machine, forwarded to the loading / sampling helpers
 * \param num_machines Number of machines taking part in training
 * \return Newly created dataset; ownership is transferred to the caller
 */
Dataset* DatasetLoader::LoadFromFile(const char* filename, int rank, int num_machines) {
  // don't support query id in data file when training in parallel.
  // A group column is configured whenever group_idx_ differs from NO_SPECIFIC, which
  // includes column index 0; testing "> 0" would let that case slip through.
  if (num_machines > 1 && !config_.pre_partition && group_idx_ != NO_SPECIFIC) {
    Log::Fatal("Using a query id without pre-partitioning the data file is not supported for parallel training.\n"
               "Please use an additional query file or pre-partition the data");
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
  data_size_t num_global_data = 0;
  std::vector<data_size_t> used_data_indices;
  auto bin_filename = CheckCanLoadFromBin(filename);
  if (bin_filename.size() == 0) {
    auto parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    dataset->data_filename_ = filename;
    dataset->label_idx_ = label_idx_;
    dataset->metadata_.Init(filename);
    if (!config_.two_round) {
      // read data to memory
      auto text_data = LoadTextDataToMemory(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      dataset->num_data_ = static_cast<data_size_t>(text_data.size());
      // sample data
      auto sample_data = SampleTextDataFromMemory(text_data);
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      // extract features
      ExtractFeaturesFromMemory(&text_data, parser.get(), dataset.get());
      text_data.clear();
    } else {
      // sample data from file
      auto sample_data = SampleTextDataFromFile(filename, dataset->metadata_, rank, num_machines, &num_global_data, &used_data_indices);
      // an empty index list means every row of the file is used locally
      if (used_data_indices.size() > 0) {
        dataset->num_data_ = static_cast<data_size_t>(used_data_indices.size());
      } else {
        dataset->num_data_ = num_global_data;
      }
      // construct feature bin mappers
      ConstructBinMappersFromTextData(rank, num_machines, sample_data, parser.get(), dataset.get());
      // initialize label
      dataset->metadata_.Init(dataset->num_data_, weight_idx_, group_idx_);
      Log::Debug("Making second pass...");
      // extract features
      ExtractFeaturesFromFile(filename, parser.get(), used_data_indices, dataset.get());
    }
  } else {
    // load data from binary file
    dataset.reset(LoadFromBinFile(filename, bin_filename.c_str(), rank, num_machines, &num_global_data, &used_data_indices));
  }
  // check meta data
  dataset->metadata_.CheckOrPartition(num_global_data, used_data_indices);
  // need to check training data
  CheckDataset(dataset.get());
  return dataset.release();
}



229
/*!
 * \brief Load a dataset whose feature layout is taken from an existing dataset
 *        (via Dataset::CreateValid) instead of being constructed from a sample.
 *
 * The file is always loaded as a whole (rank 0 of 1 machine). No CheckDataset() call
 * is made here; only the meta data is checked.
 *
 * \param filename Path of the data file
 * \param train_data Dataset passed to CreateValid() to align the new dataset with
 * \return Newly created dataset; ownership is transferred to the caller
 */
Dataset* DatasetLoader::LoadFromFileAlignWithOtherDataset(const char* filename, const Dataset* train_data) {
  data_size_t total_rows = 0;
  std::vector<data_size_t> local_row_indices;
  auto result = std::unique_ptr<Dataset>(new Dataset());
  auto binary_path = CheckCanLoadFromBin(filename);
  if (binary_path.size() != 0) {
    // a usable binary file exists, load from it
    result.reset(LoadFromBinFile(filename, binary_path.c_str(), 0, 1, &total_rows, &local_row_indices));
  } else {
    auto line_parser = std::unique_ptr<Parser>(Parser::CreateParser(filename, config_.header, 0, label_idx_));
    if (line_parser == nullptr) {
      Log::Fatal("Could not recognize data format of %s", filename);
    }
    result->data_filename_ = filename;
    result->label_idx_ = label_idx_;
    result->metadata_.Init(filename);
    if (config_.two_round) {
      TextReader<data_size_t> line_counter(filename, config_.header);
      // the row count comes from counting the lines of the data file
      result->num_data_ = static_cast<data_size_t>(line_counter.CountLine());
      total_rows = result->num_data_;
      // set up label storage, then align with the reference dataset
      result->metadata_.Init(result->num_data_, weight_idx_, group_idx_);
      result->CreateValid(train_data);
      // parse features straight from the file
      ExtractFeaturesFromFile(filename, line_parser.get(), local_row_indices, result.get());
    } else {
      // keep the whole file in memory
      auto raw_lines = LoadTextDataToMemory(filename, result->metadata_, 0, 1, &total_rows, &local_row_indices);
      result->num_data_ = static_cast<data_size_t>(raw_lines.size());
      // set up label storage, then align with the reference dataset
      result->metadata_.Init(result->num_data_, weight_idx_, group_idx_);
      result->CreateValid(train_data);
      // parse features from the in-memory lines
      ExtractFeaturesFromMemory(&raw_lines, line_parser.get(), result.get());
      raw_lines.clear();
    }
  }
  // validation data is not passed through CheckDataset(); only meta data is checked
  result->metadata_.CheckOrPartition(total_rows, local_row_indices);
  return result.release();
}

273
274
275
Dataset* DatasetLoader::LoadFromBinFile(const char* data_filename, const char* bin_filename,
                                        int rank, int num_machines, int* num_global_data,
                                        std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
276
  auto dataset = std::unique_ptr<Dataset>(new Dataset());
277
  auto reader = VirtualFileReader::Make(bin_filename);
Guolin Ke's avatar
Guolin Ke committed
278
  dataset->data_filename_ = data_filename;
279
  if (!reader->Init()) {
Guolin Ke's avatar
Guolin Ke committed
280
281
282
283
284
    Log::Fatal("Could not read binary data from %s", bin_filename);
  }

  // buffer to read binary file
  size_t buffer_size = 16 * 1024 * 1024;
Guolin Ke's avatar
Guolin Ke committed
285
  auto buffer = std::vector<char>(buffer_size);
286

287
288
  // check token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
289
290
291
292
  size_t read_cnt = reader->Read(
      buffer.data(),
      VirtualFileWriter::AlignedSize(sizeof(char) * size_of_token));
  if (read_cnt < sizeof(char) * size_of_token) {
293
294
295
    Log::Fatal("Binary file error: token has the wrong size");
  }
  if (std::string(buffer.data()) != std::string(Dataset::binary_file_token)) {
296
    Log::Fatal("Input file is not LightGBM binary file");
297
  }
Guolin Ke's avatar
Guolin Ke committed
298
299

  // read size of header
300
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
301

302
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
303
304
305
    Log::Fatal("Binary file error: header has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
306
  size_t size_of_head = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
307
308
309
310

  // re-allocmate space if not enough
  if (size_of_head > buffer_size) {
    buffer_size = size_of_head;
Guolin Ke's avatar
Guolin Ke committed
311
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
312
313
  }
  // read header
314
  read_cnt = reader->Read(buffer.data(), size_of_head);
Guolin Ke's avatar
Guolin Ke committed
315
316
317
318
319

  if (read_cnt != size_of_head) {
    Log::Fatal("Binary file error: header is incorrect");
  }
  // get header
Guolin Ke's avatar
Guolin Ke committed
320
  const char* mem_ptr = buffer.data();
Guolin Ke's avatar
Guolin Ke committed
321
  dataset->num_data_ = *(reinterpret_cast<const data_size_t*>(mem_ptr));
322
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_data_));
Guolin Ke's avatar
Guolin Ke committed
323
  dataset->num_features_ = *(reinterpret_cast<const int*>(mem_ptr));
324
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_features_));
Guolin Ke's avatar
Guolin Ke committed
325
  dataset->num_total_features_ = *(reinterpret_cast<const int*>(mem_ptr));
326
327
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(dataset->num_total_features_));
Guolin Ke's avatar
Guolin Ke committed
328
  dataset->label_idx_ = *(reinterpret_cast<const int*>(mem_ptr));
329
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->label_idx_));
330
  dataset->max_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
331
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->max_bin_));
332
  dataset->bin_construct_sample_cnt_ = *(reinterpret_cast<const int*>(mem_ptr));
333
334
  mem_ptr += VirtualFileWriter::AlignedSize(
      sizeof(dataset->bin_construct_sample_cnt_));
335
  dataset->min_data_in_bin_ = *(reinterpret_cast<const int*>(mem_ptr));
336
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->min_data_in_bin_));
337
  dataset->use_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
338
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->use_missing_));
339
  dataset->zero_as_missing_ = *(reinterpret_cast<const bool*>(mem_ptr));
340
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->zero_as_missing_));
Guolin Ke's avatar
Guolin Ke committed
341
342
  const int* tmp_feature_map = reinterpret_cast<const int*>(mem_ptr);
  dataset->used_feature_map_.clear();
Guolin Ke's avatar
Guolin Ke committed
343
  for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
344
345
    dataset->used_feature_map_.push_back(tmp_feature_map[i]);
  }
346
347
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int) *
                                            dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
348
349
  // num_groups
  dataset->num_groups_ = *(reinterpret_cast<const int*>(mem_ptr));
350
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
351
352
353
354
355
356
  // real_feature_idx_
  const int* tmp_ptr_real_feature_idx_ = reinterpret_cast<const int*>(mem_ptr);
  dataset->real_feature_idx_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->real_feature_idx_.push_back(tmp_ptr_real_feature_idx_[i]);
  }
357
358
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
359
360
361
362
363
364
  // feature2group
  const int* tmp_ptr_feature2group = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2group_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2group_.push_back(tmp_ptr_feature2group[i]);
  }
365
366
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
367
368
369
370
371
372
  // feature2subfeature
  const int* tmp_ptr_feature2subfeature = reinterpret_cast<const int*>(mem_ptr);
  dataset->feature2subfeature_.clear();
  for (int i = 0; i < dataset->num_features_; ++i) {
    dataset->feature2subfeature_.push_back(tmp_ptr_feature2subfeature[i]);
  }
373
374
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * dataset->num_features_);
Guolin Ke's avatar
Guolin Ke committed
375
376
377
378
379
380
381
382
383
384
385
  // group_bin_boundaries
  const uint64_t* tmp_ptr_group_bin_boundaries = reinterpret_cast<const uint64_t*>(mem_ptr);
  dataset->group_bin_boundaries_.clear();
  for (int i = 0; i < dataset->num_groups_ + 1; ++i) {
    dataset->group_bin_boundaries_.push_back(tmp_ptr_group_bin_boundaries[i]);
  }
  mem_ptr += sizeof(uint64_t) * (dataset->num_groups_ + 1);

  // group_feature_start_
  const int* tmp_ptr_group_feature_start = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_start_.clear();
386
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
387
388
    dataset->group_feature_start_.push_back(tmp_ptr_group_feature_start[i]);
  }
389
390
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
391
392
393
394
395
396
397

  // group_feature_cnt_
  const int* tmp_ptr_group_feature_cnt = reinterpret_cast<const int*>(mem_ptr);
  dataset->group_feature_cnt_.clear();
  for (int i = 0; i < dataset->num_groups_; ++i) {
    dataset->group_feature_cnt_.push_back(tmp_ptr_group_feature_cnt[i]);
  }
398
399
  mem_ptr +=
      VirtualFileWriter::AlignedSize(sizeof(int) * (dataset->num_groups_));
Guolin Ke's avatar
Guolin Ke committed
400

Belinda Trotta's avatar
Belinda Trotta committed
401
  if (!config_.max_bin_by_feature.empty()) {
402
403
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
404
405
406
407
408
409
410
411
412
    dataset->max_bin_by_feature_.resize(dataset->num_total_features_);
    dataset->max_bin_by_feature_.assign(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end());
  } else {
    const int32_t* tmp_ptr_max_bin_by_feature = reinterpret_cast<const int32_t*>(mem_ptr);
    dataset->max_bin_by_feature_.clear();
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      dataset->max_bin_by_feature_.push_back(tmp_ptr_max_bin_by_feature[i]);
    }
  }
413
414
  mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int32_t) *
                                            (dataset->num_total_features_));
Belinda Trotta's avatar
Belinda Trotta committed
415
416
417
418
  if (ArrayArgs<int32_t>::CheckAll(dataset->max_bin_by_feature_, -1)) {
    dataset->max_bin_by_feature_.clear();
  }

Guolin Ke's avatar
Guolin Ke committed
419
  // get feature names
Guolin Ke's avatar
Guolin Ke committed
420
  dataset->feature_names_.clear();
Guolin Ke's avatar
Guolin Ke committed
421
422
423
  // write feature names
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int str_len = *(reinterpret_cast<const int*>(mem_ptr));
424
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
Guolin Ke's avatar
Guolin Ke committed
425
    std::stringstream str_buf;
426
    auto tmp_arr = reinterpret_cast<const char*>(mem_ptr);
Guolin Ke's avatar
Guolin Ke committed
427
    for (int j = 0; j < str_len; ++j) {
428
      char tmp_char = tmp_arr[j];
Guolin Ke's avatar
Guolin Ke committed
429
430
      str_buf << tmp_char;
    }
431
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(char) * str_len);
Guolin Ke's avatar
Guolin Ke committed
432
    dataset->feature_names_.emplace_back(str_buf.str());
Guolin Ke's avatar
Guolin Ke committed
433
  }
434
435
436
437
  // get forced_bin_bounds_
  dataset->forced_bin_bounds_ = std::vector<std::vector<double>>(dataset->num_total_features_, std::vector<double>());
  for (int i = 0; i < dataset->num_total_features_; ++i) {
    int num_bounds = *(reinterpret_cast<const int*>(mem_ptr));
438
    mem_ptr += VirtualFileWriter::AlignedSize(sizeof(int));
439
    dataset->forced_bin_bounds_[i] = std::vector<double>();
440
441
    const double* tmp_ptr_forced_bounds =
        reinterpret_cast<const double*>(mem_ptr);
442
443
444
445
446
447
    for (int j = 0; j < num_bounds; ++j) {
      double bound = tmp_ptr_forced_bounds[j];
      dataset->forced_bin_bounds_[i].push_back(bound);
    }
    mem_ptr += num_bounds * sizeof(double);
  }
Guolin Ke's avatar
Guolin Ke committed
448
449

  // read size of meta data
450
  read_cnt = reader->Read(buffer.data(), sizeof(size_t));
Guolin Ke's avatar
Guolin Ke committed
451

452
  if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
453
454
455
    Log::Fatal("Binary file error: meta data has the wrong size");
  }

Guolin Ke's avatar
Guolin Ke committed
456
  size_t size_of_metadata = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
457
458
459
460

  // re-allocate space if not enough
  if (size_of_metadata > buffer_size) {
    buffer_size = size_of_metadata;
Guolin Ke's avatar
Guolin Ke committed
461
    buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
462
463
  }
  //  read meta data
464
  read_cnt = reader->Read(buffer.data(), size_of_metadata);
Guolin Ke's avatar
Guolin Ke committed
465
466
467
468
469

  if (read_cnt != size_of_metadata) {
    Log::Fatal("Binary file error: meta data is incorrect");
  }
  // load meta data
Guolin Ke's avatar
Guolin Ke committed
470
  dataset->metadata_.LoadFromMemory(buffer.data());
Guolin Ke's avatar
Guolin Ke committed
471

472
473
  *num_global_data = dataset->num_data_;
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
474
  // sample local used data if need to partition
Guolin Ke's avatar
Guolin Ke committed
475
  if (num_machines > 1 && !config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
476
477
478
479
    const data_size_t* query_boundaries = dataset->metadata_.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
480
        if (random_.NextShort(0, num_machines) == rank) {
481
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
482
483
484
485
486
487
488
489
490
        }
      }
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = dataset->metadata_.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      for (data_size_t i = 0; i < dataset->num_data_; ++i) {
        if (qid >= num_queries) {
491
492
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
493
494
495
496
        }
        if (i >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
497
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
498
499
500
501
502
            is_query_used = true;
          }
          ++qid;
        }
        if (is_query_used) {
503
          used_data_indices->push_back(i);
Guolin Ke's avatar
Guolin Ke committed
504
505
506
        }
      }
    }
507
    dataset->num_data_ = static_cast<data_size_t>((*used_data_indices).size());
Guolin Ke's avatar
Guolin Ke committed
508
  }
509
  dataset->metadata_.PartitionLabel(*used_data_indices);
Guolin Ke's avatar
Guolin Ke committed
510
  // read feature data
Guolin Ke's avatar
Guolin Ke committed
511
  for (int i = 0; i < dataset->num_groups_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
512
    // read feature size
513
514
    read_cnt = reader->Read(buffer.data(), sizeof(size_t));
    if (read_cnt != sizeof(size_t)) {
Guolin Ke's avatar
Guolin Ke committed
515
516
      Log::Fatal("Binary file error: feature %d has the wrong size", i);
    }
Guolin Ke's avatar
Guolin Ke committed
517
    size_t size_of_feature = *(reinterpret_cast<size_t*>(buffer.data()));
Guolin Ke's avatar
Guolin Ke committed
518
519
520
    // re-allocate space if not enough
    if (size_of_feature > buffer_size) {
      buffer_size = size_of_feature;
Guolin Ke's avatar
Guolin Ke committed
521
      buffer.resize(buffer_size);
Guolin Ke's avatar
Guolin Ke committed
522
523
    }

524
    read_cnt = reader->Read(buffer.data(), size_of_feature);
Guolin Ke's avatar
Guolin Ke committed
525
526
527
528

    if (read_cnt != size_of_feature) {
      Log::Fatal("Binary file error: feature %d is incorrect, read count: %d", i, read_cnt);
    }
Guolin Ke's avatar
Guolin Ke committed
529
    dataset->feature_groups_.emplace_back(std::unique_ptr<FeatureGroup>(
530
531
      new FeatureGroup(buffer.data(),
                       *num_global_data,
532
                       *used_data_indices)));
Guolin Ke's avatar
Guolin Ke committed
533
  }
Guolin Ke's avatar
Guolin Ke committed
534
  dataset->feature_groups_.shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
535
  dataset->is_finish_load_ = true;
Guolin Ke's avatar
Guolin Ke committed
536
  return dataset.release();
Guolin Ke's avatar
Guolin Ke committed
537
538
}

539

540
541
542
/*!
 * \brief Construct a Dataset from per-column sampled values.
 *
 * A BinMapper is found for every non-ignored column, either locally (one machine)
 * or split across machines and exchanged with Network::Allgather. The resulting
 * bin mappers are then handed to Dataset::Construct.
 *
 * \param sample_values sample_values[j] is passed to BinMapper::FindBin as the sampled values of column j
 * \param sample_indices Forwarded unchanged to Dataset::Construct
 * \param num_col Number of columns addressable through sample_values / num_per_col
 * \param num_per_col num_per_col[j] is passed to BinMapper::FindBin as the number of entries in sample_values[j]
 * \param total_sample_size Total sample count given to BinMapper::FindBin; also used to scale min_data_in_leaf
 * \param num_data Row count used to create the Dataset
 * \return Newly allocated Dataset, released to the caller
 */
Dataset* DatasetLoader::ConstructFromSampleData(double** sample_values,
                                                int** sample_indices, int num_col, const int* num_per_col,
                                                size_t total_sample_size, data_size_t num_data) {
  int num_total_features = num_col;
  if (Network::num_machines() > 1) {
    // machines may see a different number of columns, agree on the largest one
    num_total_features = Network::GlobalSyncUpByMax(num_total_features);
  }
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(num_total_features);
  // fill feature_names_ if not header
  if (feature_names_.empty()) {
    for (int i = 0; i < num_col; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
  if (!config_.max_bin_by_feature.empty()) {
    CHECK_EQ(static_cast<size_t>(num_col), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
  }

  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path, num_col, categorical_features_);

  // min_data_in_leaf rescaled from the full data size to the sample size
  const data_size_t filter_cnt = static_cast<data_size_t>(
    static_cast<double>(config_.min_data_in_leaf * total_sample_size) / num_data);
  const bool use_global_max_bin = config_.max_bin_by_feature.empty();
  if (Network::num_machines() == 1) {
    // if only one machine, find bin locally
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < num_col; ++i) {
      OMP_LOOP_EX_BEGIN();
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
        bool feat_is_unconstrained = ((config_.monotone_constraints.size() == 0) || (config_.monotone_constraints[i] == 0));
        if (!feat_is_unconstrained) {
          Log::Fatal("The output cannot be monotone with respect to categorical features");
        }
      }
      bin_mappers[i].reset(new BinMapper());
      const int max_bin = use_global_max_bin ? config_.max_bin : config_.max_bin_by_feature[i];
      bin_mappers[i]->FindBin(sample_values[i], num_per_col[i], total_sample_size,
                              max_bin, config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter,
                              bin_type, config_.use_missing, config_.zero_as_missing,
                              forced_bin_bounds[i]);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
  } else {
    // if have multi-machines, need to find bin distributed
    // different machines will find bin for different features
    int num_machines = Network::num_machines();
    int rank = Network::rank();
    // start and len will store the process feature indices for different machines
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
    int step = (num_total_features + num_machines - 1) / num_machines;
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
      len[i] = std::min(step, num_total_features - start[i]);
      start[i + 1] = start[i] + len[i];
    }
    len[num_machines - 1] = num_total_features - start[num_machines - 1];
    OMP_INIT_EX();
    #pragma omp parallel for schedule(guided)
    for (int i = 0; i < len[rank]; ++i) {
      OMP_LOOP_EX_BEGIN();
      // i is the slot in the local part of bin_mappers, feature_idx is the global column index
      const int feature_idx = start[rank] + i;
      if (ignore_features_.count(feature_idx) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(feature_idx)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
      if (num_col <= feature_idx) {
        // this machine has no samples for the column, keep the default-constructed mapper
        continue;
      }
      const int max_bin = use_global_max_bin ? config_.max_bin : config_.max_bin_by_feature[feature_idx];
      // forced bin bounds are indexed by the global column index, like the sample arrays
      bin_mappers[i]->FindBin(sample_values[feature_idx], num_per_col[feature_idx],
                              total_sample_size, max_bin, config_.min_data_in_bin,
                              filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
                              config_.zero_as_missing, forced_bin_bounds[feature_idx]);
      OMP_LOOP_EX_END();
    }
    OMP_THROW_EX();
    // total number of bytes needed to serialize the local bin mappers
    comm_size_t self_buf_size = 0;
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
    }
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
    for (int i = 0; i < len[rank]; ++i) {
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
      // free
      bin_mappers[i].reset(nullptr);
    }
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
    }
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
    // gather global feature bin mappers
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
    // restore features bins from buffer
    for (int i = 0; i < num_total_features; ++i) {
      if (ignore_features_.count(i) > 0) {
        bin_mappers[i] = nullptr;
        continue;
      }
      bin_mappers[i].reset(new BinMapper());
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
    }
  }
  auto dataset = std::unique_ptr<Dataset>(new Dataset(num_data));
  dataset->Construct(&bin_mappers, num_total_features, forced_bin_bounds, sample_indices, sample_values, num_per_col, num_col, total_sample_size, config_);
  dataset->set_feature_names(feature_names_);
  return dataset.release();
}
Guolin Ke's avatar
Guolin Ke committed
691
692
693
694
695
696


// ---- private functions ----

void DatasetLoader::CheckDataset(const Dataset* dataset) {
  if (dataset->num_data_ <= 0) {
Guolin Ke's avatar
Guolin Ke committed
697
    Log::Fatal("Data file %s is empty", dataset->data_filename_.c_str());
Guolin Ke's avatar
Guolin Ke committed
698
  }
699
700
  if (dataset->feature_names_.size() != static_cast<size_t>(dataset->num_total_features_)) {
    Log::Fatal("Size of feature name error, should be %d, got %d", dataset->num_total_features_,
701
               static_cast<int>(dataset->feature_names_.size()));
702
  }
Guolin Ke's avatar
Guolin Ke committed
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
  bool is_feature_order_by_group = true;
  int last_group = -1;
  int last_sub_feature = -1;
  // if features are ordered, not need to use hist_buf
  for (int i = 0; i < dataset->num_features_; ++i) {
    int group = dataset->feature2group_[i];
    int sub_feature = dataset->feature2subfeature_[i];
    if (group < last_group) {
      is_feature_order_by_group = false;
    } else if (group == last_group) {
      if (sub_feature <= last_sub_feature) {
        is_feature_order_by_group = false;
        break;
      }
    }
    last_group = group;
    last_sub_feature = sub_feature;
  }
  if (!is_feature_order_by_group) {
722
    Log::Fatal("Features in dataset should be ordered by group");
Guolin Ke's avatar
Guolin Ke committed
723
  }
Guolin Ke's avatar
Guolin Ke committed
724
725
726
}

std::vector<std::string> DatasetLoader::LoadTextDataToMemory(const char* filename, const Metadata& metadata,
727
728
                                                             int rank, int num_machines, int* num_global_data,
                                                             std::vector<data_size_t>* used_data_indices) {
729
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
730
  used_data_indices->clear();
Guolin Ke's avatar
Guolin Ke committed
731
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
732
733
734
735
736
737
738
739
740
    // read all lines
    *num_global_data = text_reader.ReadAllLines();
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();

    if (query_boundaries == nullptr) {
      // if not contain query data, minimal sample unit is one record
      *num_global_data = text_reader.ReadAndFilterLines([this, rank, num_machines](data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
741
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
742
743
744
745
746
747
748
749
750
751
752
753
754
755
          return true;
        } else {
          return false;
        }
      }, used_data_indices);
    } else {
      // if contain query data, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.ReadAndFilterLines(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
756
757
          Log::Fatal("Current query exceeds the range of the query file,\n"
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
758
759
760
761
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
762
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
763
764
765
766
767
768
769
770
771
772
773
774
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
      }, used_data_indices);
    }
  }
  return std::move(text_reader.Lines());
}

std::vector<std::string> DatasetLoader::SampleTextDataFromMemory(const std::vector<std::string>& data) {
Guolin Ke's avatar
Guolin Ke committed
775
  int sample_cnt = config_.bin_construct_sample_cnt;
776
777
  if (static_cast<size_t>(sample_cnt) > data.size()) {
    sample_cnt = static_cast<int>(data.size());
778
  }
779
  auto sample_indices = random_.Sample(static_cast<int>(data.size()), sample_cnt);
Guolin Ke's avatar
Guolin Ke committed
780
  std::vector<std::string> out(sample_indices.size());
Guolin Ke's avatar
Guolin Ke committed
781
782
  for (size_t i = 0; i < sample_indices.size(); ++i) {
    const size_t idx = sample_indices[i];
Guolin Ke's avatar
Guolin Ke committed
783
    out[i] = data[idx];
Guolin Ke's avatar
Guolin Ke committed
784
785
786
787
  }
  return out;
}

788
789
790
std::vector<std::string> DatasetLoader::SampleTextDataFromFile(const char* filename, const Metadata& metadata,
                                                               int rank, int num_machines, int* num_global_data,
                                                               std::vector<data_size_t>* used_data_indices) {
Guolin Ke's avatar
Guolin Ke committed
791
  const data_size_t sample_cnt = static_cast<data_size_t>(config_.bin_construct_sample_cnt);
792
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
793
  std::vector<std::string> out_data;
Guolin Ke's avatar
Guolin Ke committed
794
  if (num_machines == 1 || config_.pre_partition) {
Guolin Ke's avatar
Guolin Ke committed
795
    *num_global_data = static_cast<data_size_t>(text_reader.SampleFromFile(&random_, sample_cnt, &out_data));
Guolin Ke's avatar
Guolin Ke committed
796
797
798
799
800
801
802
  } else {  // need partition data
            // get query data
    const data_size_t* query_boundaries = metadata.query_boundaries();
    if (query_boundaries == nullptr) {
      // if not contain query file, minimal sample unit is one record
      *num_global_data = text_reader.SampleAndFilterFromFile([this, rank, num_machines]
      (data_size_t) {
Guolin Ke's avatar
Guolin Ke committed
803
        if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
804
805
806
807
          return true;
        } else {
          return false;
        }
Guolin Ke's avatar
Guolin Ke committed
808
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
809
810
811
812
813
814
815
816
817
    } else {
      // if contain query file, minimal sample unit is one query
      data_size_t num_queries = metadata.num_queries();
      data_size_t qid = -1;
      bool is_query_used = false;
      *num_global_data = text_reader.SampleAndFilterFromFile(
        [this, rank, num_machines, &qid, &query_boundaries, &is_query_used, num_queries]
      (data_size_t line_idx) {
        if (qid >= num_queries) {
818
819
          Log::Fatal("Query id exceeds the range of the query file, "
                     "please ensure the query file is correct");
Guolin Ke's avatar
Guolin Ke committed
820
821
822
823
        }
        if (line_idx >= query_boundaries[qid + 1]) {
          // if is new query
          is_query_used = false;
Guolin Ke's avatar
Guolin Ke committed
824
          if (random_.NextShort(0, num_machines) == rank) {
Guolin Ke's avatar
Guolin Ke committed
825
826
827
828
829
            is_query_used = true;
          }
          ++qid;
        }
        return is_query_used;
Guolin Ke's avatar
Guolin Ke committed
830
      }, used_data_indices, &random_, sample_cnt, &out_data);
Guolin Ke's avatar
Guolin Ke committed
831
832
833
834
835
    }
  }
  return out_data;
}

836
837
838
void DatasetLoader::ConstructBinMappersFromTextData(int rank, int num_machines,
                                                    const std::vector<std::string>& sample_data,
                                                    const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
839
  std::vector<std::vector<double>> sample_values;
Guolin Ke's avatar
Guolin Ke committed
840
  std::vector<std::vector<int>> sample_indices;
Guolin Ke's avatar
Guolin Ke committed
841
842
  std::vector<std::pair<int, double>> oneline_features;
  double label;
Guolin Ke's avatar
Guolin Ke committed
843
  for (int i = 0; i < static_cast<int>(sample_data.size()); ++i) {
Guolin Ke's avatar
Guolin Ke committed
844
845
846
847
    oneline_features.clear();
    // parse features
    parser->ParseOneLine(sample_data[i].c_str(), &oneline_features, &label);
    for (std::pair<int, double>& inner_data : oneline_features) {
848
      if (static_cast<size_t>(inner_data.first) >= sample_values.size()) {
Guolin Ke's avatar
Guolin Ke committed
849
850
        sample_values.resize(inner_data.first + 1);
        sample_indices.resize(inner_data.first + 1);
851
      }
Guolin Ke's avatar
Guolin Ke committed
852
      if (std::fabs(inner_data.second) > kZeroThreshold || std::isnan(inner_data.second)) {
Guolin Ke's avatar
Guolin Ke committed
853
854
        sample_values[inner_data.first].emplace_back(inner_data.second);
        sample_indices[inner_data.first].emplace_back(i);
Guolin Ke's avatar
Guolin Ke committed
855
856
857
858
      }
    }
  }

Guolin Ke's avatar
Guolin Ke committed
859
  dataset->feature_groups_.clear();
860
861
862
863
864
  dataset->num_total_features_ = std::max(static_cast<int>(sample_values.size()), parser->NumFeatures());
  if (num_machines > 1) {
    dataset->num_total_features_ = Network::GlobalSyncUpByMax(dataset->num_total_features_);
  }
  if (!feature_names_.empty()) {
865
    CHECK_EQ(dataset->num_total_features_, static_cast<int>(feature_names_.size()));
866
  }
Guolin Ke's avatar
Guolin Ke committed
867

Belinda Trotta's avatar
Belinda Trotta committed
868
  if (!config_.max_bin_by_feature.empty()) {
869
870
    CHECK_EQ(static_cast<size_t>(dataset->num_total_features_), config_.max_bin_by_feature.size());
    CHECK_GT(*(std::min_element(config_.max_bin_by_feature.begin(), config_.max_bin_by_feature.end())), 1);
Belinda Trotta's avatar
Belinda Trotta committed
871
872
  }

873
874
  // get forced split
  std::string forced_bins_path = config_.forcedbins_filename;
875
876
  std::vector<std::vector<double>> forced_bin_bounds = DatasetLoader::GetForcedBins(forced_bins_path,
                                                                                    dataset->num_total_features_,
877
878
                                                                                    categorical_features_);

Guolin Ke's avatar
Guolin Ke committed
879
880
881
882
883
884
  // check the range of label_idx, weight_idx and group_idx
  CHECK(label_idx_ >= 0 && label_idx_ <= dataset->num_total_features_);
  CHECK(weight_idx_ < 0 || weight_idx_ < dataset->num_total_features_);
  CHECK(group_idx_ < 0 || group_idx_ < dataset->num_total_features_);

  // fill feature_names_ if not header
Guolin Ke's avatar
Guolin Ke committed
885
  if (feature_names_.empty()) {
Guolin Ke's avatar
Guolin Ke committed
886
887
888
889
890
891
    for (int i = 0; i < dataset->num_total_features_; ++i) {
      std::stringstream str_buf;
      str_buf << "Column_" << i;
      feature_names_.push_back(str_buf.str());
    }
  }
892
  dataset->set_feature_names(feature_names_);
Guolin Ke's avatar
Guolin Ke committed
893
  std::vector<std::unique_ptr<BinMapper>> bin_mappers(dataset->num_total_features_);
Guolin Ke's avatar
Guolin Ke committed
894
  const data_size_t filter_cnt = static_cast<data_size_t>(
Guolin Ke's avatar
Guolin Ke committed
895
    static_cast<double>(config_.min_data_in_leaf* sample_data.size()) / dataset->num_data_);
Guolin Ke's avatar
Guolin Ke committed
896
897
898
  // start find bins
  if (num_machines == 1) {
    // if only one machine, find bin locally
899
    OMP_INIT_EX();
900
    #pragma omp parallel for schedule(guided)
Guolin Ke's avatar
Guolin Ke committed
901
    for (int i = 0; i < static_cast<int>(sample_values.size()); ++i) {
902
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
903
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
904
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
905
906
        continue;
      }
907
908
909
910
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(i)) {
        bin_type = BinType::CategoricalBin;
      }
Guolin Ke's avatar
Guolin Ke committed
911
      bin_mappers[i].reset(new BinMapper());
Belinda Trotta's avatar
Belinda Trotta committed
912
913
      if (config_.max_bin_by_feature.empty()) {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
914
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
915
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
916
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
917
918
      } else {
        bin_mappers[i]->FindBin(sample_values[i].data(), static_cast<int>(sample_values[i].size()),
919
                                sample_data.size(), config_.max_bin_by_feature[i],
920
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing,
921
                                config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
922
      }
923
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
924
    }
925
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
926
927
  } else {
    // start and len will store the process feature indices for different machines
928
    // machine i will find bins for features in [ start[i], start[i] + len[i] )
Guolin Ke's avatar
Guolin Ke committed
929
930
    std::vector<int> start(num_machines);
    std::vector<int> len(num_machines);
931
    int step = (dataset->num_total_features_ + num_machines - 1) / num_machines;
Guolin Ke's avatar
Guolin Ke committed
932
933
934
935
    if (step < 1) { step = 1; }

    start[0] = 0;
    for (int i = 0; i < num_machines - 1; ++i) {
936
      len[i] = std::min(step, dataset->num_total_features_ - start[i]);
Guolin Ke's avatar
Guolin Ke committed
937
938
      start[i + 1] = start[i] + len[i];
    }
939
    len[num_machines - 1] = dataset->num_total_features_ - start[num_machines - 1];
940
    OMP_INIT_EX();
941
    #pragma omp parallel for schedule(guided)
942
    for (int i = 0; i < len[rank]; ++i) {
943
      OMP_LOOP_EX_BEGIN();
944
945
946
947
948
949
950
951
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
      BinType bin_type = BinType::NumericalBin;
      if (categorical_features_.count(start[rank] + i)) {
        bin_type = BinType::CategoricalBin;
      }
      bin_mappers[i].reset(new BinMapper());
Nikita Titov's avatar
Nikita Titov committed
952
      if (static_cast<int>(sample_values.size()) <= start[rank] + i) {
953
954
        continue;
      }
Belinda Trotta's avatar
Belinda Trotta committed
955
      if (config_.max_bin_by_feature.empty()) {
956
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
957
                                static_cast<int>(sample_values[start[rank] + i].size()),
958
                                sample_data.size(), config_.max_bin, config_.min_data_in_bin,
959
                                filter_cnt, config_.feature_pre_filter, bin_type, config_.use_missing, config_.zero_as_missing,
960
                                forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
961
      } else {
962
        bin_mappers[i]->FindBin(sample_values[start[rank] + i].data(),
Belinda Trotta's avatar
Belinda Trotta committed
963
                                static_cast<int>(sample_values[start[rank] + i].size()),
964
                                sample_data.size(), config_.max_bin_by_feature[i],
965
                                config_.min_data_in_bin, filter_cnt, config_.feature_pre_filter, bin_type,
966
                                config_.use_missing, config_.zero_as_missing, forced_bin_bounds[i]);
Belinda Trotta's avatar
Belinda Trotta committed
967
      }
968
      OMP_LOOP_EX_END();
969
    }
970
    OMP_THROW_EX();
971
    comm_size_t self_buf_size = 0;
Guolin Ke's avatar
Guolin Ke committed
972
    for (int i = 0; i < len[rank]; ++i) {
973
974
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
Guolin Ke's avatar
Guolin Ke committed
975
      }
976
      self_buf_size += static_cast<comm_size_t>(bin_mappers[i]->SizesInByte());
Guolin Ke's avatar
Guolin Ke committed
977
    }
978
979
    std::vector<char> input_buffer(self_buf_size);
    auto cp_ptr = input_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
980
    for (int i = 0; i < len[rank]; ++i) {
Guolin Ke's avatar
Guolin Ke committed
981
982
983
      if (ignore_features_.count(start[rank] + i) > 0) {
        continue;
      }
984
985
      bin_mappers[i]->CopyTo(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
986
987
      // free
      bin_mappers[i].reset(nullptr);
Guolin Ke's avatar
Guolin Ke committed
988
    }
989
990
991
992
    std::vector<comm_size_t> size_len = Network::GlobalArray(self_buf_size);
    std::vector<comm_size_t> size_start(num_machines, 0);
    for (int i = 1; i < num_machines; ++i) {
      size_start[i] = size_start[i - 1] + size_len[i - 1];
Guolin Ke's avatar
Guolin Ke committed
993
    }
994
995
    comm_size_t total_buffer_size = size_start[num_machines - 1] + size_len[num_machines - 1];
    std::vector<char> output_buffer(total_buffer_size);
Guolin Ke's avatar
Guolin Ke committed
996
    // gather global feature bin mappers
997
998
    Network::Allgather(input_buffer.data(), size_start.data(), size_len.data(), output_buffer.data(), total_buffer_size);
    cp_ptr = output_buffer.data();
Guolin Ke's avatar
Guolin Ke committed
999
    // restore features bins from buffer
1000
    for (int i = 0; i < dataset->num_total_features_; ++i) {
Guolin Ke's avatar
Guolin Ke committed
1001
      if (ignore_features_.count(i) > 0) {
Guolin Ke's avatar
Guolin Ke committed
1002
        bin_mappers[i] = nullptr;
Guolin Ke's avatar
Guolin Ke committed
1003
1004
        continue;
      }
Guolin Ke's avatar
Guolin Ke committed
1005
      bin_mappers[i].reset(new BinMapper());
1006
1007
      bin_mappers[i]->CopyFrom(cp_ptr);
      cp_ptr += bin_mappers[i]->SizesInByte();
Guolin Ke's avatar
Guolin Ke committed
1008
1009
    }
  }
1010
  dataset->Construct(&bin_mappers, dataset->num_total_features_, forced_bin_bounds, Common::Vector2Ptr<int>(&sample_indices).data(),
Guolin Ke's avatar
Guolin Ke committed
1011
                     Common::Vector2Ptr<double>(&sample_values).data(),
1012
                     Common::VectorSize<int>(sample_indices).data(), static_cast<int>(sample_indices.size()), sample_data.size(), config_);
Guolin Ke's avatar
Guolin Ke committed
1013
1014
1015
}

/*! \brief Extract local features from memory */
Guolin Ke's avatar
Guolin Ke committed
1016
void DatasetLoader::ExtractFeaturesFromMemory(std::vector<std::string>* text_data, const Parser* parser, Dataset* dataset) {
Guolin Ke's avatar
Guolin Ke committed
1017
1018
  std::vector<std::pair<int, double>> oneline_features;
  double tmp_label = 0.0f;
Guolin Ke's avatar
Guolin Ke committed
1019
  auto& ref_text_data = *text_data;
Guolin Ke's avatar
Guolin Ke committed
1020
  if (predict_fun_ == nullptr) {
1021
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1022
    // if doesn't need to prediction with initial model
1023
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1024
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1025
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1026
1027
1028
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1029
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1030
      // set label
1031
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1032
      // free processed line:
Guolin Ke's avatar
Guolin Ke committed
1033
      ref_text_data[i].clear();
Guolin Ke's avatar
Guolin Ke committed
1034
1035
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
Guolin Ke's avatar
Guolin Ke committed
1036
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1037
1038
      // push data
      for (auto& inner_data : oneline_features) {
1039
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1040
1041
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1042
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1043
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1044
1045
1046
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1047
1048
        } else {
          if (inner_data.first == weight_idx_) {
1049
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1050
1051
1052
1053
1054
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1055
      dataset->FinishOneRow(tid, i, is_feature_added);
1056
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1057
    }
1058
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1059
  } else {
1060
    OMP_INIT_EX();
Guolin Ke's avatar
Guolin Ke committed
1061
    // if need to prediction with initial model
1062
    std::vector<double> init_score(dataset->num_data_ * num_class_);
1063
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1064
    for (data_size_t i = 0; i < dataset->num_data_; ++i) {
1065
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1066
1067
1068
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
Guolin Ke's avatar
Guolin Ke committed
1069
      parser->ParseOneLine(ref_text_data[i].c_str(), &oneline_features, &tmp_label);
Guolin Ke's avatar
Guolin Ke committed
1070
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1071
1072
      std::vector<double> oneline_init_score(num_class_);
      predict_fun_(oneline_features, oneline_init_score.data());
1073
      for (int k = 0; k < num_class_; ++k) {
1074
        init_score[k * dataset->num_data_ + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1075
1076
      }
      // set label
1077
      dataset->metadata_.SetLabelAt(i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1078
1079
1080
1081
1082
      // free processed line:
      text_data[i].clear();
      // shrink_to_fit will be very slow in linux, and seems not free memory, disable for now
      // text_reader_->Lines()[i].shrink_to_fit();
      // push data
Guolin Ke's avatar
Guolin Ke committed
1083
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1084
      for (auto& inner_data : oneline_features) {
1085
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1086
1087
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1088
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1089
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1090
1091
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
1092
          dataset->feature_groups_[group]->PushData(tid, sub_feature, i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1093
1094
        } else {
          if (inner_data.first == weight_idx_) {
1095
            dataset->metadata_.SetWeightAt(i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1096
1097
1098
1099
1100
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1101
      dataset->FinishOneRow(tid, i, is_feature_added);
1102
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1103
    }
1104
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1105
    // metadata_ will manage space of init_score
1106
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1107
  }
Guolin Ke's avatar
Guolin Ke committed
1108
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1109
  // text data can be free after loaded feature values
Guolin Ke's avatar
Guolin Ke committed
1110
  text_data->clear();
Guolin Ke's avatar
Guolin Ke committed
1111
1112
1113
}

/*! \brief Extract local features from file */
1114
1115
void DatasetLoader::ExtractFeaturesFromFile(const char* filename, const Parser* parser,
                                            const std::vector<data_size_t>& used_data_indices, Dataset* dataset) {
1116
  std::vector<double> init_score;
Guolin Ke's avatar
Guolin Ke committed
1117
  if (predict_fun_ != nullptr) {
1118
    init_score = std::vector<double>(dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1119
1120
1121
1122
1123
1124
  }
  std::function<void(data_size_t, const std::vector<std::string>&)> process_fun =
    [this, &init_score, &parser, &dataset]
  (data_size_t start_idx, const std::vector<std::string>& lines) {
    std::vector<std::pair<int, double>> oneline_features;
    double tmp_label = 0.0f;
1125
    OMP_INIT_EX();
1126
    #pragma omp parallel for schedule(static) private(oneline_features) firstprivate(tmp_label)
Guolin Ke's avatar
Guolin Ke committed
1127
    for (data_size_t i = 0; i < static_cast<data_size_t>(lines.size()); ++i) {
1128
      OMP_LOOP_EX_BEGIN();
Guolin Ke's avatar
Guolin Ke committed
1129
1130
1131
1132
1133
      const int tid = omp_get_thread_num();
      oneline_features.clear();
      // parser
      parser->ParseOneLine(lines[i].c_str(), &oneline_features, &tmp_label);
      // set initial score
Guolin Ke's avatar
Guolin Ke committed
1134
      if (!init_score.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1135
1136
        std::vector<double> oneline_init_score(num_class_);
        predict_fun_(oneline_features, oneline_init_score.data());
1137
        for (int k = 0; k < num_class_; ++k) {
1138
          init_score[k * dataset->num_data_ + start_idx + i] = static_cast<double>(oneline_init_score[k]);
Guolin Ke's avatar
Guolin Ke committed
1139
1140
1141
        }
      }
      // set label
1142
      dataset->metadata_.SetLabelAt(start_idx + i, static_cast<label_t>(tmp_label));
Guolin Ke's avatar
Guolin Ke committed
1143
      std::vector<bool> is_feature_added(dataset->num_features_, false);
Guolin Ke's avatar
Guolin Ke committed
1144
1145
      // push data
      for (auto& inner_data : oneline_features) {
1146
        if (inner_data.first >= dataset->num_total_features_) { continue; }
Guolin Ke's avatar
Guolin Ke committed
1147
1148
        int feature_idx = dataset->used_feature_map_[inner_data.first];
        if (feature_idx >= 0) {
Guolin Ke's avatar
Guolin Ke committed
1149
          is_feature_added[feature_idx] = true;
Guolin Ke's avatar
Guolin Ke committed
1150
          // if is used feature
Guolin Ke's avatar
Guolin Ke committed
1151
1152
1153
          int group = dataset->feature2group_[feature_idx];
          int sub_feature = dataset->feature2subfeature_[feature_idx];
          dataset->feature_groups_[group]->PushData(tid, sub_feature, start_idx + i, inner_data.second);
Guolin Ke's avatar
Guolin Ke committed
1154
1155
        } else {
          if (inner_data.first == weight_idx_) {
1156
            dataset->metadata_.SetWeightAt(start_idx + i, static_cast<label_t>(inner_data.second));
Guolin Ke's avatar
Guolin Ke committed
1157
1158
1159
1160
1161
          } else if (inner_data.first == group_idx_) {
            dataset->metadata_.SetQueryAt(start_idx + i, static_cast<data_size_t>(inner_data.second));
          }
        }
      }
Guolin Ke's avatar
Guolin Ke committed
1162
      dataset->FinishOneRow(tid, i, is_feature_added);
1163
      OMP_LOOP_EX_END();
Guolin Ke's avatar
Guolin Ke committed
1164
    }
1165
    OMP_THROW_EX();
Guolin Ke's avatar
Guolin Ke committed
1166
  };
1167
  TextReader<data_size_t> text_reader(filename, config_.header, config_.file_load_progress_interval_bytes);
Guolin Ke's avatar
Guolin Ke committed
1168
  if (!used_data_indices.empty()) {
Guolin Ke's avatar
Guolin Ke committed
1169
1170
1171
1172
1173
1174
1175
1176
    // only need part of data
    text_reader.ReadPartAndProcessParallel(used_data_indices, process_fun);
  } else {
    // need full data
    text_reader.ReadAllAndProcessParallel(process_fun);
  }

  // metadata_ will manage space of init_score
Guolin Ke's avatar
Guolin Ke committed
1177
  if (!init_score.empty()) {
1178
    dataset->metadata_.SetInitScore(init_score.data(), dataset->num_data_ * num_class_);
Guolin Ke's avatar
Guolin Ke committed
1179
  }
Guolin Ke's avatar
Guolin Ke committed
1180
  dataset->FinishLoad();
Guolin Ke's avatar
Guolin Ke committed
1181
1182
1183
}

/*! \brief Check can load from binary file */
1184
std::string DatasetLoader::CheckCanLoadFromBin(const char* filename) {
Guolin Ke's avatar
Guolin Ke committed
1185
1186
1187
  std::string bin_filename(filename);
  bin_filename.append(".bin");

1188
  auto reader = VirtualFileReader::Make(bin_filename.c_str());
Guolin Ke's avatar
Guolin Ke committed
1189

1190
  if (!reader->Init()) {
1191
    bin_filename = std::string(filename);
1192
1193
    reader = VirtualFileReader::Make(bin_filename.c_str());
    if (!reader->Init()) {
1194
      Log::Fatal("Cannot open data file %s", bin_filename.c_str());
1195
    }
1196
  }
1197
1198
1199
1200
1201

  size_t buffer_size = 256;
  auto buffer = std::vector<char>(buffer_size);
  // read size of token
  size_t size_of_token = std::strlen(Dataset::binary_file_token);
1202
  size_t read_cnt = reader->Read(buffer.data(), size_of_token);
1203
1204
  if (read_cnt == size_of_token
      && std::string(buffer.data()) == std::string(Dataset::binary_file_token)) {
1205
    return bin_filename;
Guolin Ke's avatar
Guolin Ke committed
1206
  } else {
1207
    return std::string();
Guolin Ke's avatar
Guolin Ke committed
1208
1209
1210
  }
}

1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223


std::vector<std::vector<double>> DatasetLoader::GetForcedBins(std::string forced_bins_path, int num_total_features,
                                                              const std::unordered_set<int>& categorical_features) {
  std::vector<std::vector<double>> forced_bins(num_total_features, std::vector<double>());
  if (forced_bins_path != "") {
    std::ifstream forced_bins_stream(forced_bins_path.c_str());
    if (forced_bins_stream.fail()) {
      Log::Warning("Could not open %s. Will ignore.", forced_bins_path.c_str());
    } else {
      std::stringstream buffer;
      buffer << forced_bins_stream.rdbuf();
      std::string err;
Guolin Ke's avatar
Guolin Ke committed
1224
      Json forced_bins_json = Json::parse(buffer.str(), &err);
1225
1226
1227
1228
      CHECK(forced_bins_json.is_array());
      std::vector<Json> forced_bins_arr = forced_bins_json.array_items();
      for (size_t i = 0; i < forced_bins_arr.size(); ++i) {
        int feature_num = forced_bins_arr[i]["feature"].int_value();
Nikita Titov's avatar
Nikita Titov committed
1229
        CHECK_LT(feature_num, num_total_features);
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
        if (categorical_features.count(feature_num)) {
          Log::Warning("Feature %d is categorical. Will ignore forced bins for this  feature.", feature_num);
        } else {
          std::vector<Json> bounds_arr = forced_bins_arr[i]["bin_upper_bound"].array_items();
          for (size_t j = 0; j < bounds_arr.size(); ++j) {
            forced_bins[feature_num].push_back(bounds_arr[j].number_value());
          }
        }
      }
      // remove duplicates
      for (int i = 0; i < num_total_features; ++i) {
        auto new_end = std::unique(forced_bins[i].begin(), forced_bins[i].end());
        forced_bins[i].erase(new_end, forced_bins[i].end());
      }
    }
  }
  return forced_bins;
}

1249
}  // namespace LightGBM